diff --git a/changelog.d/18828.feature b/changelog.d/18828.feature new file mode 100644 index 0000000000..e7f3541de4 --- /dev/null +++ b/changelog.d/18828.feature @@ -0,0 +1 @@ +Cleanly shutdown `SynapseHomeServer` object. diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index e170aabdae..0b854cdba5 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -68,18 +68,42 @@ PROMETHEUS_METRIC_MISSING_FROM_LIST_TO_CHECK = ErrorCode( category="per-homeserver-tenant-metrics", ) +PREFER_SYNAPSE_CLOCK_CALL_LATER = ErrorCode( + "call-later-not-tracked", + "Prefer using `synapse.util.Clock.call_later` instead of `reactor.callLater`", + category="synapse-reactor-clock", +) + +PREFER_SYNAPSE_CLOCK_LOOPING_CALL = ErrorCode( + "prefer-synapse-clock-looping-call", + "Prefer using `synapse.util.Clock.looping_call` instead of `task.LoopingCall`", + category="synapse-reactor-clock", +) + PREFER_SYNAPSE_CLOCK_CALL_WHEN_RUNNING = ErrorCode( "prefer-synapse-clock-call-when-running", - "`synapse.util.Clock.call_when_running` should be used instead of `reactor.callWhenRunning`", + "Prefer using `synapse.util.Clock.call_when_running` instead of `reactor.callWhenRunning`", category="synapse-reactor-clock", ) PREFER_SYNAPSE_CLOCK_ADD_SYSTEM_EVENT_TRIGGER = ErrorCode( "prefer-synapse-clock-add-system-event-trigger", - "`synapse.util.Clock.add_system_event_trigger` should be used instead of `reactor.addSystemEventTrigger`", + "Prefer using `synapse.util.Clock.add_system_event_trigger` instead of `reactor.addSystemEventTrigger`", category="synapse-reactor-clock", ) +MULTIPLE_INTERNAL_CLOCKS_CREATED = ErrorCode( + "multiple-internal-clocks", + "Only one instance of `clock.Clock` should be created", + category="synapse-reactor-clock", +) + +UNTRACKED_BACKGROUND_PROCESS = ErrorCode( + "untracked-background-process", + "Prefer using `HomeServer.run_as_background_process` method over the bare `run_as_background_process`", + category="synapse-tracked-calls", +) + class Sentinel(enum.Enum): # defining a sentinel in this way allows mypy to correctly handle the @@ -222,6 +246,18 @@ class SynapsePlugin(Plugin): # callback, let's just pass it in while we have it. return lambda ctx: check_prometheus_metric_instantiation(ctx, fullname) + if fullname == "twisted.internet.task.LoopingCall": + return check_looping_call + + if fullname == "synapse.util.clock.Clock": + return check_clock_creation + + if ( + fullname + == "synapse.metrics.background_process_metrics.run_as_background_process" + ): + return check_background_process + return None def get_method_signature_hook( @@ -241,6 +277,13 @@ class SynapsePlugin(Plugin): ): return check_is_cacheable_wrapper + if fullname in ( + "twisted.internet.interfaces.IReactorTime.callLater", + "synapse.types.ISynapseThreadlessReactor.callLater", + "synapse.types.ISynapseReactor.callLater", + ): + return check_call_later + if fullname in ( "twisted.internet.interfaces.IReactorCore.callWhenRunning", "synapse.types.ISynapseThreadlessReactor.callWhenRunning", @@ -258,6 +301,78 @@ class SynapsePlugin(Plugin): return None +def check_clock_creation(ctx: FunctionSigContext) -> CallableType: + """ + Ensure that the only `clock.Clock` instance is the one used by the `HomeServer`. + This is so that the `HomeServer` can cancel any tracked delayed or looping calls + during server shutdown. + + Args: + ctx: The `FunctionSigContext` from mypy. + """ + signature: CallableType = ctx.default_signature + ctx.api.fail( + "Expected the only `clock.Clock` instance to be the one used by the `HomeServer`. " + "This is so that the `HomeServer` can cancel any tracked delayed or looping calls " + "during server shutdown", + ctx.context, + code=MULTIPLE_INTERNAL_CLOCKS_CREATED, + ) + + return signature + + +def check_call_later(ctx: MethodSigContext) -> CallableType: + """ + Ensure that the `reactor.callLater` callsites aren't used. + + `synapse.util.Clock.call_later` should always be used instead of `reactor.callLater`. + This is because the `synapse.util.Clock` tracks delayed calls in order to cancel any + outstanding calls during server shutdown. Delayed calls which are either short lived + (<~60s) or frequently called and can be tracked via other means could be candidates for + using `synapse.util.Clock.call_later` with `call_later_cancel_on_shutdown` set to + `False`. There shouldn't be a need to use `reactor.callLater` outside of tests or the + `Clock` class itself. If a need arises, you can use a type ignore comment to disable the + check, e.g. `# type: ignore[call-later-not-tracked]`. + + Args: + ctx: The `FunctionSigContext` from mypy. + """ + signature: CallableType = ctx.default_signature + ctx.api.fail( + "Expected all `reactor.callLater` calls to use `synapse.util.Clock.call_later` " + "instead. This is so that long lived calls can be tracked for cancellation during " + "server shutdown", + ctx.context, + code=PREFER_SYNAPSE_CLOCK_CALL_LATER, + ) + + return signature + + +def check_looping_call(ctx: FunctionSigContext) -> CallableType: + """ + Ensure that the `task.LoopingCall` callsites aren't used. + + `synapse.util.Clock.looping_call` should always be used instead of `task.LoopingCall`. + `synapse.util.Clock` tracks looping calls in order to cancel any outstanding calls + during server shutdown. + + Args: + ctx: The `FunctionSigContext` from mypy. + """ + signature: CallableType = ctx.default_signature + ctx.api.fail( + "Expected all `task.LoopingCall` instances to use `synapse.util.Clock.looping_call` " + "instead. This is so that long lived calls can be tracked for cancellation during " + "server shutdown", + ctx.context, + code=PREFER_SYNAPSE_CLOCK_LOOPING_CALL, + ) + + return signature + + def check_call_when_running(ctx: MethodSigContext) -> CallableType: """ Ensure that the `reactor.callWhenRunning` callsites aren't used. @@ -312,6 +427,27 @@ def check_add_system_event_trigger(ctx: MethodSigContext) -> CallableType: return signature +def check_background_process(ctx: FunctionSigContext) -> CallableType: + """ + Ensure that calls to `run_as_background_process` use the `HomeServer` method. + This is so that the `HomeServer` can cancel any running background processes during + server shutdown. + + Args: + ctx: The `FunctionSigContext` from mypy. + """ + signature: CallableType = ctx.default_signature + ctx.api.fail( + "Prefer using `HomeServer.run_as_background_process` method over the bare " + "`run_as_background_process`. This is so that the `HomeServer` can cancel " + "any background processes during server shutdown", + ctx.context, + code=UNTRACKED_BACKGROUND_PROCESS, + ) + + return signature + + def analyze_prometheus_metric_classes(ctx: ClassDefContext) -> None: """ Cross-check the list of Prometheus metric classes against the diff --git a/synapse/_scripts/generate_workers_map.py b/synapse/_scripts/generate_workers_map.py index 8878e364e2..f66c01040c 100755 --- a/synapse/_scripts/generate_workers_map.py +++ b/synapse/_scripts/generate_workers_map.py @@ -157,7 +157,12 @@ def get_registered_paths_for_default( # TODO We only do this to avoid an error, but don't need the database etc hs.setup() registered_paths = get_registered_paths_for_hs(hs) - hs.cleanup() + # NOTE: a more robust implementation would properly shutdown/cleanup each server + # to avoid resource buildup. + # However, the call to `shutdown` is `async` so it would require additional complexity here. + # We are intentionally skipping this cleanup because this is a short-lived, one-off + # utility script where the simpler approach is sufficient and we shouldn't run into + # any resource buildup issues. return registered_paths diff --git a/synapse/_scripts/update_synapse_database.py b/synapse/_scripts/update_synapse_database.py index caaecda161..ad02f0ed88 100644 --- a/synapse/_scripts/update_synapse_database.py +++ b/synapse/_scripts/update_synapse_database.py @@ -28,7 +28,6 @@ import yaml from twisted.internet import defer, reactor as reactor_ from synapse.config.homeserver import HomeServerConfig -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.server import HomeServer from synapse.storage import DataStore from synapse.types import ISynapseReactor @@ -53,7 +52,6 @@ class MockHomeserver(HomeServer): def run_background_updates(hs: HomeServer) -> None: - server_name = hs.hostname main = hs.get_datastores().main state = hs.get_datastores().state @@ -67,9 +65,8 @@ def run_background_updates(hs: HomeServer) -> None: def run() -> None: # Apply all background updates on the database. defer.ensureDeferred( - run_as_background_process( + hs.run_as_background_process( "background_updates", - server_name, run_background_updates, ) ) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 5638724896..655f684ecf 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -28,6 +28,7 @@ import sys import traceback import warnings from textwrap import indent +from threading import Thread from typing import ( TYPE_CHECKING, Any, @@ -40,6 +41,7 @@ from typing import ( Tuple, cast, ) +from wsgiref.simple_server import WSGIServer from cryptography.utils import CryptographyDeprecationWarning from typing_extensions import ParamSpec @@ -97,22 +99,47 @@ reactor = cast(ISynapseReactor, _reactor) logger = logging.getLogger(__name__) -# list of tuples of function, args list, kwargs dict -_sighup_callbacks: List[ - Tuple[Callable[..., None], Tuple[object, ...], Dict[str, object]] -] = [] +_instance_id_to_sighup_callbacks_map: Dict[ + str, List[Tuple[Callable[..., None], Tuple[object, ...], Dict[str, object]]] +] = {} +""" +Map from homeserver instance_id to a list of callbacks. + +We use `instance_id` instead of `server_name` because it's possible to have multiple +workers running in the same process with the same `server_name`. +""" P = ParamSpec("P") -def register_sighup(func: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: +def register_sighup( + homeserver_instance_id: str, + func: Callable[P, None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: """ Register a function to be called when a SIGHUP occurs. Args: + homeserver_instance_id: The unique ID for this Synapse process instance + (`hs.get_instance_id()`) that this hook is associated with. func: Function to be called when sent a SIGHUP signal. *args, **kwargs: args and kwargs to be passed to the target function. """ - _sighup_callbacks.append((func, args, kwargs)) + + _instance_id_to_sighup_callbacks_map.setdefault(homeserver_instance_id, []).append( + (func, args, kwargs) + ) + + +def unregister_sighups(instance_id: str) -> None: + """ + Unregister all sighup functions associated with this Synapse instance. + + Args: + instance_id: Unique ID for this Synapse process instance. + """ + _instance_id_to_sighup_callbacks_map.pop(instance_id, []) def start_worker_reactor( @@ -281,7 +308,9 @@ def register_start( clock.call_when_running(lambda: defer.ensureDeferred(wrapper())) -def listen_metrics(bind_addresses: StrCollection, port: int) -> None: +def listen_metrics( + bind_addresses: StrCollection, port: int +) -> List[Tuple[WSGIServer, Thread]]: """ Start Prometheus metrics server. @@ -294,14 +323,22 @@ def listen_metrics(bind_addresses: StrCollection, port: int) -> None: bytecode at a time), this still works because the metrics thread can preempt the Twisted reactor thread between bytecode boundaries and the metrics thread gets scheduled with roughly equal priority to the Twisted reactor thread. + + Returns: + List of WSGIServer with the thread they are running on. """ from prometheus_client import start_http_server as start_http_server_prometheus from synapse.metrics import RegistryProxy + servers: List[Tuple[WSGIServer, Thread]] = [] for host in bind_addresses: logger.info("Starting metrics listener on %s:%d", host, port) - start_http_server_prometheus(port, addr=host, registry=RegistryProxy) + server, thread = start_http_server_prometheus( + port, addr=host, registry=RegistryProxy + ) + servers.append((server, thread)) + return servers def listen_manhole( @@ -309,7 +346,7 @@ def listen_manhole( port: int, manhole_settings: ManholeConfig, manhole_globals: dict, -) -> None: +) -> List[Port]: # twisted.conch.manhole 21.1.0 uses "int_from_bytes", which produces a confusing # warning. It's fixed by https://github.com/twisted/twisted/pull/1522), so # suppress the warning for now. @@ -321,7 +358,7 @@ def listen_manhole( from synapse.util.manhole import manhole - listen_tcp( + return listen_tcp( bind_addresses, port, manhole(settings=manhole_settings, globals=manhole_globals), @@ -498,7 +535,7 @@ def refresh_certificate(hs: "HomeServer") -> None: logger.info("Context factories updated.") -async def start(hs: "HomeServer") -> None: +async def start(hs: "HomeServer", freeze: bool = True) -> None: """ Start a Synapse server or worker. @@ -509,6 +546,11 @@ async def start(hs: "HomeServer") -> None: Args: hs: homeserver instance + freeze: whether to freeze the homeserver base objects in the garbage collector. + May improve garbage collection performance by marking objects with an effectively + static lifetime as frozen so they don't need to be considered for cleanup. + If you ever want to `shutdown` the homeserver, this needs to be + False otherwise the homeserver cannot be garbage collected after `shutdown`. """ server_name = hs.hostname reactor = hs.get_reactor() @@ -541,12 +583,17 @@ async def start(hs: "HomeServer") -> None: # we're not using systemd. sdnotify(b"RELOADING=1") - for i, args, kwargs in _sighup_callbacks: - i(*args, **kwargs) + for sighup_callbacks in _instance_id_to_sighup_callbacks_map.values(): + for func, args, kwargs in sighup_callbacks: + func(*args, **kwargs) sdnotify(b"READY=1") - return run_as_background_process( + # It's okay to ignore the linter error here and call + # `run_as_background_process` directly because `_handle_sighup` operates + # outside of the scope of a specific `HomeServer` instance and holds no + # references to it which would prevent a clean shutdown. + return run_as_background_process( # type: ignore[untracked-background-process] "sighup", server_name, _handle_sighup, @@ -564,8 +611,8 @@ async def start(hs: "HomeServer") -> None: signal.signal(signal.SIGHUP, run_sighup) - register_sighup(refresh_certificate, hs) - register_sighup(reload_cache_config, hs.config) + register_sighup(hs.get_instance_id(), refresh_certificate, hs) + register_sighup(hs.get_instance_id(), reload_cache_config, hs.config) # Apply the cache config. hs.config.caches.resize_all_caches() @@ -603,7 +650,11 @@ async def start(hs: "HomeServer") -> None: logger.info("Shutting down...") # Log when we start the shut down process. - hs.get_clock().add_system_event_trigger("before", "shutdown", log_shutdown) + hs.register_sync_shutdown_handler( + phase="before", + eventType="shutdown", + shutdown_func=log_shutdown, + ) setup_sentry(hs) setup_sdnotify(hs) @@ -632,18 +683,24 @@ async def start(hs: "HomeServer") -> None: # `REQUIRED_ON_BACKGROUND_TASK_STARTUP` start_phone_stats_home(hs) - # We now freeze all allocated objects in the hopes that (almost) - # everything currently allocated are things that will be used for the - # rest of time. Doing so means less work each GC (hopefully). - # - # PyPy does not (yet?) implement gc.freeze() - if hasattr(gc, "freeze"): - gc.collect() - gc.freeze() + if freeze: + # We now freeze all allocated objects in the hopes that (almost) + # everything currently allocated are things that will be used for the + # rest of time. Doing so means less work each GC (hopefully). + # + # Note that freezing the homeserver object means that it won't be able to be + # garbage collected in the case of attempting an in-memory `shutdown`. This only + # needs to be considered if such a case is desirable. Exiting the entire Python + # process will function expectedly either way. + # + # PyPy does not (yet?) implement gc.freeze() + if hasattr(gc, "freeze"): + gc.collect() + gc.freeze() - # Speed up shutdowns by freezing all allocated objects. This moves everything - # into the permanent generation and excludes them from the final GC. - atexit.register(gc.freeze) + # Speed up process exit by freezing all allocated objects. This moves everything + # into the permanent generation and excludes them from the final GC. + atexit.register(gc.freeze) def reload_cache_config(config: HomeServerConfig) -> None: diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 51b8adaa27..7e8b47c20a 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -278,11 +278,13 @@ class GenericWorkerServer(HomeServer): self._listen_http(listener) elif listener.type == "manhole": if isinstance(listener, TCPListenerConfig): - _base.listen_manhole( - listener.bind_addresses, - listener.port, - manhole_settings=self.config.server.manhole_settings, - manhole_globals={"hs": self}, + self._listening_services.extend( + _base.listen_manhole( + listener.bind_addresses, + listener.port, + manhole_settings=self.config.server.manhole_settings, + manhole_globals={"hs": self}, + ) ) else: raise ConfigError( @@ -296,9 +298,11 @@ class GenericWorkerServer(HomeServer): ) else: if isinstance(listener, TCPListenerConfig): - _base.listen_metrics( - listener.bind_addresses, - listener.port, + self._metrics_listeners.extend( + _base.listen_metrics( + listener.bind_addresses, + listener.port, + ) ) else: raise ConfigError( diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 35d633d527..3c691906ca 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -22,7 +22,7 @@ import logging import os import sys -from typing import Dict, Iterable, List +from typing import Dict, Iterable, List, Optional from twisted.internet.tcp import Port from twisted.web.resource import EncodingResourceWrapper, Resource @@ -70,6 +70,7 @@ from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.well_known import well_known_resource from synapse.server import HomeServer from synapse.storage import DataStore +from synapse.types import ISynapseReactor from synapse.util.check_dependencies import VERSION, check_requirements from synapse.util.httpresourcetree import create_resource_tree from synapse.util.module_loader import load_module @@ -277,11 +278,13 @@ class SynapseHomeServer(HomeServer): ) elif listener.type == "manhole": if isinstance(listener, TCPListenerConfig): - _base.listen_manhole( - listener.bind_addresses, - listener.port, - manhole_settings=self.config.server.manhole_settings, - manhole_globals={"hs": self}, + self._listening_services.extend( + _base.listen_manhole( + listener.bind_addresses, + listener.port, + manhole_settings=self.config.server.manhole_settings, + manhole_globals={"hs": self}, + ) ) else: raise ConfigError( @@ -294,9 +297,11 @@ class SynapseHomeServer(HomeServer): ) else: if isinstance(listener, TCPListenerConfig): - _base.listen_metrics( - listener.bind_addresses, - listener.port, + self._metrics_listeners.extend( + _base.listen_metrics( + listener.bind_addresses, + listener.port, + ) ) else: raise ConfigError( @@ -340,12 +345,23 @@ def load_or_generate_config(argv_options: List[str]) -> HomeServerConfig: return config -def setup(config: HomeServerConfig) -> SynapseHomeServer: +def setup( + config: HomeServerConfig, + reactor: Optional[ISynapseReactor] = None, + freeze: bool = True, +) -> SynapseHomeServer: """ Create and setup a Synapse homeserver instance given a configuration. Args: config: The configuration for the homeserver. + reactor: Optionally provide a reactor to use. Can be useful in different + scenarios that you want control over the reactor, such as tests. + freeze: whether to freeze the homeserver base objects in the garbage collector. + May improve garbage collection performance by marking objects with an effectively + static lifetime as frozen so they don't need to be considered for cleanup. + If you ever want to `shutdown` the homeserver, this needs to be + False otherwise the homeserver cannot be garbage collected after `shutdown`. Returns: A homeserver instance. @@ -384,6 +400,7 @@ def setup(config: HomeServerConfig) -> SynapseHomeServer: config.server.server_name, config=config, version_string=f"Synapse/{VERSION}", + reactor=reactor, ) setup_logging(hs, config, use_worker_options=False) @@ -405,7 +422,7 @@ def setup(config: HomeServerConfig) -> SynapseHomeServer: # Loading the provider metadata also ensures the provider config is valid. await oidc.load_metadata() - await _base.start(hs) + await _base.start(hs, freeze) hs.get_datastores().main.db_pool.updates.start_doing_background_updates() diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 7b8e7fe700..4bbc33cba2 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -29,9 +29,6 @@ from prometheus_client import Gauge from twisted.internet import defer from synapse.metrics import SERVER_NAME_LABEL -from synapse.metrics.background_process_metrics import ( - run_as_background_process, -) from synapse.types import JsonDict from synapse.util.constants import ( MILLISECONDS_PER_SECOND, @@ -87,8 +84,6 @@ def phone_stats_home( stats: JsonDict, stats_process: List[Tuple[int, "resource.struct_rusage"]] = _stats_process, ) -> "defer.Deferred[None]": - server_name = hs.hostname - async def _phone_stats_home( hs: "HomeServer", stats: JsonDict, @@ -202,8 +197,8 @@ def phone_stats_home( except Exception as e: logger.warning("Error reporting stats: %s", e) - return run_as_background_process( - "phone_stats_home", server_name, _phone_stats_home, hs, stats, stats_process + return hs.run_as_background_process( + "phone_stats_home", _phone_stats_home, hs, stats, stats_process ) @@ -265,9 +260,8 @@ def start_phone_stats_home(hs: "HomeServer") -> None: float(hs.config.server.max_mau_value) ) - return run_as_background_process( + return hs.run_as_background_process( "generate_monthly_active_users", - server_name, _generate_monthly_active_users, ) @@ -287,10 +281,16 @@ def start_phone_stats_home(hs: "HomeServer") -> None: # We need to defer this init for the cases that we daemonize # otherwise the process ID we get is that of the non-daemon process - clock.call_later(0, performance_stats_init) + clock.call_later( + 0, + performance_stats_init, + ) # We wait 5 minutes to send the first set of stats as the server can # be quite busy the first few minutes clock.call_later( - INITIAL_DELAY_BEFORE_FIRST_PHONE_HOME_SECONDS, phone_stats_home, hs, stats + INITIAL_DELAY_BEFORE_FIRST_PHONE_HOME_SECONDS, + phone_stats_home, + hs, + stats, ) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 2d8d382e68..1d0735ca1d 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -23,15 +23,33 @@ import logging import re from enum import Enum -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern, Sequence +from typing import ( + TYPE_CHECKING, + Dict, + Iterable, + List, + Optional, + Pattern, + Sequence, + cast, +) import attr from netaddr import IPSet +from twisted.internet import reactor + from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, UserID +from synapse.types import ( + DeviceListUpdates, + ISynapseThreadlessReactor, + JsonDict, + JsonMapping, + UserID, +) from synapse.util.caches.descriptors import _CacheContext, cached +from synapse.util.clock import Clock if TYPE_CHECKING: from synapse.appservice.api import ApplicationServiceApi @@ -98,6 +116,15 @@ class ApplicationService: self.sender = sender # The application service user should be part of the server's domain. self.server_name = sender.domain # nb must be called this for @cached + + # Ideally we would require passing in the `HomeServer` `Clock` instance. + # However this is not currently possible as there are places which use + # `@cached` that aren't aware of the `HomeServer` instance. + # nb must be called this for @cached + self.clock = Clock( + cast(ISynapseThreadlessReactor, reactor), server_name=self.server_name + ) # type: ignore[multiple-internal-clocks] + self.namespaces = self._check_namespaces(namespaces) self.id = id self.ip_range_whitelist = ip_range_whitelist diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index c8678406a1..b4de759b67 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -81,7 +81,6 @@ from synapse.appservice import ( from synapse.appservice.api import ApplicationServiceApi from synapse.events import EventBase from synapse.logging.context import run_in_background -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main import DataStore from synapse.types import DeviceListUpdates, JsonMapping from synapse.util.clock import Clock @@ -200,6 +199,7 @@ class _ServiceQueuer: ) self.server_name = hs.hostname self.clock = hs.get_clock() + self.hs = hs self._store = hs.get_datastores().main def start_background_request(self, service: ApplicationService) -> None: @@ -207,9 +207,7 @@ class _ServiceQueuer: if service.id in self.requests_in_flight: return - run_as_background_process( - "as-sender", self.server_name, self._send_request, service - ) + self.hs.run_as_background_process("as-sender", self._send_request, service) async def _send_request(self, service: ApplicationService) -> None: # sanity-check: we shouldn't get here if this service already has a sender @@ -361,6 +359,7 @@ class _TransactionController: def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname self.clock = hs.get_clock() + self.hs = hs self.store = hs.get_datastores().main self.as_api = hs.get_application_service_api() @@ -448,6 +447,7 @@ class _TransactionController: recoverer = self.RECOVERER_CLASS( self.server_name, self.clock, + self.hs, self.store, self.as_api, service, @@ -494,6 +494,7 @@ class _Recoverer: self, server_name: str, clock: Clock, + hs: "HomeServer", store: DataStore, as_api: ApplicationServiceApi, service: ApplicationService, @@ -501,6 +502,7 @@ class _Recoverer: ): self.server_name = server_name self.clock = clock + self.hs = hs self.store = store self.as_api = as_api self.service = service @@ -513,9 +515,8 @@ class _Recoverer: logger.info("Scheduling retries on %s in %fs", self.service.id, delay) self.scheduled_recovery = self.clock.call_later( delay, - run_as_background_process, + self.hs.run_as_background_process, "as-recoverer", - self.server_name, self.retry, ) @@ -535,9 +536,8 @@ class _Recoverer: if self.scheduled_recovery: self.clock.cancel_call_later(self.scheduled_recovery) # Run a retry, which will resechedule a recovery if it fails. - run_as_background_process( + self.hs.run_as_background_process( "retry", - self.server_name, self.retry, ) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 0531ae7875..9dde4c4003 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -345,7 +345,9 @@ def setup_logging( # Add a SIGHUP handler to reload the logging configuration, if one is available. from synapse.app import _base as appbase - appbase.register_sighup(_reload_logging_config, log_config_path) + appbase.register_sighup( + hs.get_instance_id(), _reload_logging_config, log_config_path + ) # Log immediately so we can grep backwards. logger.warning("***** STARTING SERVER *****") diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index eac2d776f9..258bc29357 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -172,7 +172,7 @@ class Keyring: _FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]] ] = BatchingQueue( name="keyring_server", - server_name=self.server_name, + hs=hs, clock=hs.get_clock(), # The method called to fetch each key process_batch_callback=self._inner_fetch_key_requests, @@ -194,6 +194,14 @@ class Keyring: valid_until_ts=2**63, # fake future timestamp ) + def shutdown(self) -> None: + """ + Prepares the KeyRing for garbage collection by shutting down it's queues. + """ + self._fetch_keys_queue.shutdown() + for key_fetcher in self._key_fetchers: + key_fetcher.shutdown() + async def verify_json_for_server( self, server_name: str, @@ -479,11 +487,17 @@ class KeyFetcher(metaclass=abc.ABCMeta): self.server_name = hs.hostname self._queue = BatchingQueue( name=self.__class__.__name__, - server_name=self.server_name, + hs=hs, clock=hs.get_clock(), process_batch_callback=self._fetch_keys, ) + def shutdown(self) -> None: + """ + Prepares the KeyFetcher for garbage collection by shutting down it's queue. + """ + self._queue.shutdown() + async def get_keys( self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int ) -> Dict[str, FetchKeyResult]: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 41595043d1..8c91336dbc 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -148,6 +148,7 @@ class FederationClient(FederationBase): self._get_pdu_cache: ExpiringCache[str, Tuple[EventBase, str]] = ExpiringCache( cache_name="get_pdu_cache", server_name=self.server_name, + hs=self.hs, clock=self._clock, max_len=1000, expiry_ms=120 * 1000, @@ -167,6 +168,7 @@ class FederationClient(FederationBase): ] = ExpiringCache( cache_name="get_room_hierarchy_cache", server_name=self.server_name, + hs=self.hs, clock=self._clock, max_len=1000, expiry_ms=5 * 60 * 1000, diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 2fdee9ac54..759df9836b 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -144,6 +144,9 @@ class FederationRemoteSendQueue(AbstractFederationSender): self.clock.looping_call(self._clear_queue, 30 * 1000) + def shutdown(self) -> None: + """Stops this federation sender instance from sending further transactions.""" + def _next_pos(self) -> int: pos = self.pos self.pos += 1 diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 8e3619d1bc..4410ffc5c5 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -168,7 +168,6 @@ from synapse.metrics import ( events_processed_counter, ) from synapse.metrics.background_process_metrics import ( - run_as_background_process, wrap_as_background_process, ) from synapse.types import ( @@ -232,6 +231,11 @@ WAKEUP_INTERVAL_BETWEEN_DESTINATIONS_SEC = 5 class AbstractFederationSender(metaclass=abc.ABCMeta): + @abc.abstractmethod + def shutdown(self) -> None: + """Stops this federation sender instance from sending further transactions.""" + raise NotImplementedError() + @abc.abstractmethod def notify_new_events(self, max_token: RoomStreamToken) -> None: """This gets called when we have some new events we might want to @@ -326,6 +330,7 @@ class _DestinationWakeupQueue: _MAX_TIME_IN_QUEUE = 30.0 sender: "FederationSender" = attr.ib() + hs: "HomeServer" = attr.ib() server_name: str = attr.ib() """ Our homeserver name (used to label metrics) (`hs.hostname`). @@ -453,18 +458,30 @@ class FederationSender(AbstractFederationSender): 1.0 / hs.config.ratelimiting.federation_rr_transactions_per_room_per_second ) self._destination_wakeup_queue = _DestinationWakeupQueue( - self, self.server_name, self.clock, max_delay_s=rr_txn_interval_per_room_s + self, + hs, + self.server_name, + self.clock, + max_delay_s=rr_txn_interval_per_room_s, ) + # It is important for `_is_shutdown` to be instantiated before the looping call + # for `wake_destinations_needing_catchup`. + self._is_shutdown = False + # Regularly wake up destinations that have outstanding PDUs to be caught up self.clock.looping_call_now( - run_as_background_process, + self.hs.run_as_background_process, WAKEUP_RETRY_PERIOD_SEC * 1000.0, "wake_destinations_needing_catchup", - self.server_name, self._wake_destinations_needing_catchup, ) + def shutdown(self) -> None: + self._is_shutdown = True + for queue in self._per_destination_queues.values(): + queue.shutdown() + def _get_per_destination_queue( self, destination: str ) -> Optional[PerDestinationQueue]: @@ -503,16 +520,15 @@ class FederationSender(AbstractFederationSender): return # fire off a processing loop in the background - run_as_background_process( + self.hs.run_as_background_process( "process_event_queue_for_federation", - self.server_name, self._process_event_queue_loop, ) async def _process_event_queue_loop(self) -> None: try: self._is_processing = True - while True: + while not self._is_shutdown: last_token = await self.store.get_federation_out_pos("events") ( next_token, @@ -1123,7 +1139,7 @@ class FederationSender(AbstractFederationSender): last_processed: Optional[str] = None - while True: + while not self._is_shutdown: destinations_to_wake = ( await self.store.get_catch_up_outstanding_destinations(last_processed) ) diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 4c844d403a..845af92fac 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -28,6 +28,8 @@ from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tupl import attr from prometheus_client import Counter +from twisted.internet import defer + from synapse.api.constants import EduTypes from synapse.api.errors import ( FederationDeniedError, @@ -41,7 +43,6 @@ from synapse.handlers.presence import format_user_presence_state from synapse.logging import issue9533_logger from synapse.logging.opentracing import SynapseTags, set_tag from synapse.metrics import SERVER_NAME_LABEL, sent_transactions_counter -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import JsonDict, ReadReceipt from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.visibility import filter_events_for_server @@ -79,6 +80,7 @@ MAX_PRESENCE_STATES_PER_EDU = 50 class PerDestinationQueue: """ Manages the per-destination transmission queues. + Runs until `shutdown()` is called on the queue. Args: hs @@ -94,6 +96,7 @@ class PerDestinationQueue: destination: str, ): self.server_name = hs.hostname + self._hs = hs self._clock = hs.get_clock() self._storage_controllers = hs.get_storage_controllers() self._store = hs.get_datastores().main @@ -117,6 +120,8 @@ class PerDestinationQueue: self._destination = destination self.transmission_loop_running = False + self._transmission_loop_enabled = True + self.active_transmission_loop: Optional[defer.Deferred] = None # Flag to signal to any running transmission loop that there is new data # queued up to be sent. @@ -171,6 +176,20 @@ class PerDestinationQueue: def __str__(self) -> str: return "PerDestinationQueue[%s]" % self._destination + def shutdown(self) -> None: + """Instruct the queue to stop processing any further requests""" + self._transmission_loop_enabled = False + # The transaction manager must be shutdown before cancelling the active + # transmission loop. Otherwise the transmission loop can enter a new cycle of + # sleeping before retrying since the shutdown flag of the _transaction_manager + # hasn't been set yet. + self._transaction_manager.shutdown() + try: + if self.active_transmission_loop is not None: + self.active_transmission_loop.cancel() + except Exception: + pass + def pending_pdu_count(self) -> int: return len(self._pending_pdus) @@ -309,11 +328,14 @@ class PerDestinationQueue: ) return + if not self._transmission_loop_enabled: + logger.warning("Shutdown has been requested. Not sending transaction") + return + logger.debug("TX [%s] Starting transaction loop", self._destination) - run_as_background_process( + self.active_transmission_loop = self._hs.run_as_background_process( "federation_transaction_transmission_loop", - self.server_name, self._transaction_transmission_loop, ) @@ -321,13 +343,13 @@ class PerDestinationQueue: pending_pdus: List[EventBase] = [] try: self.transmission_loop_running = True - # This will throw if we wouldn't retry. We do this here so we fail # quickly, but we will later check this again in the http client, # hence why we throw the result away. await get_retry_limiter( destination=self._destination, our_server_name=self.server_name, + hs=self._hs, clock=self._clock, store=self._store, ) @@ -339,7 +361,7 @@ class PerDestinationQueue: # not caught up yet return - while True: + while self._transmission_loop_enabled: self._new_data_to_send = False async with _TransactionQueueManager(self) as ( @@ -352,8 +374,8 @@ class PerDestinationQueue: # If we've gotten told about new things to send during # checking for things to send, we try looking again. # Otherwise new PDUs or EDUs might arrive in the meantime, - # but not get sent because we hold the - # `transmission_loop_running` flag. + # but not get sent because we currently have an + # `_active_transmission_loop` running. if self._new_data_to_send: continue else: @@ -442,6 +464,7 @@ class PerDestinationQueue: ) finally: # We want to be *very* sure we clear this after we stop processing + self.active_transmission_loop = None self.transmission_loop_running = False async def _catch_up_transmission_loop(self) -> None: @@ -469,7 +492,7 @@ class PerDestinationQueue: last_successful_stream_ordering: int = _tmp_last_successful_stream_ordering # get at most 50 catchup room/PDUs - while True: + while self._transmission_loop_enabled: event_ids = await self._store.get_catch_up_room_event_ids( self._destination, last_successful_stream_ordering ) diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index b548d9ed70..f47c011487 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -72,6 +72,12 @@ class TransactionManager: # HACK to get unique tx id self._next_txn_id = int(self.clock.time_msec()) + self._is_shutdown = False + + def shutdown(self) -> None: + self._is_shutdown = True + self._transport_layer.shutdown() + @measure_func("_send_new_transaction") async def send_new_transaction( self, @@ -86,6 +92,12 @@ class TransactionManager: edus: List of EDUs to send """ + if self._is_shutdown: + logger.warning( + "TransactionManager has been shutdown, not sending transaction" + ) + return + # Make a transaction-sending opentracing span. This span follows on from # all the edus in that transaction. This needs to be done since there is # no active span here, so if the edus were not received by the remote the diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 5a5dc45f10..02e56e8e27 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -70,6 +70,9 @@ class TransportLayerClient: self.client = hs.get_federation_http_client() self._is_mine_server_name = hs.is_mine_server_name + def shutdown(self) -> None: + self.client.shutdown() + async def get_room_state_ids( self, destination: str, room_id: str, event_id: str ) -> JsonDict: diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 39a22b8cbb..eed50ef69a 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -37,10 +37,8 @@ logger = logging.getLogger(__name__) class AccountValidityHandler: def __init__(self, hs: "HomeServer"): - self.hs = hs - self.server_name = ( - hs.hostname - ) # nb must be called this for @wrap_as_background_process + self.hs = hs # nb must be called this for @wrap_as_background_process + self.server_name = hs.hostname self.config = hs.config self.store = hs.get_datastores().main self.send_email_handler = hs.get_send_email_handler() diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index bf36cf39a1..6536d9fe51 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -47,7 +47,6 @@ from synapse.metrics import ( event_processing_loop_room_count, ) from synapse.metrics.background_process_metrics import ( - run_as_background_process, wrap_as_background_process, ) from synapse.storage.databases.main.directory import RoomAliasMapping @@ -76,9 +75,8 @@ events_processed_counter = Counter( class ApplicationServicesHandler: def __init__(self, hs: "HomeServer"): - self.server_name = ( - hs.hostname - ) # nb must be called this for @wrap_as_background_process + self.server_name = hs.hostname + self.hs = hs # nb must be called this for @wrap_as_background_process self.store = hs.get_datastores().main self.is_mine_id = hs.is_mine_id self.appservice_api = hs.get_application_service_api() @@ -171,8 +169,8 @@ class ApplicationServicesHandler: except Exception: logger.error("Application Services Failure") - run_as_background_process( - "as_scheduler", self.server_name, start_scheduler + self.hs.run_as_background_process( + "as_scheduler", start_scheduler ) self.started_scheduler = True diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index c0684380a7..204dffd288 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -24,7 +24,6 @@ from typing import TYPE_CHECKING, Optional from synapse.api.constants import Membership from synapse.api.errors import SynapseError -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.deactivate_account import ( ReplicationNotifyAccountDeactivatedServlet, ) @@ -272,8 +271,8 @@ class DeactivateAccountHandler: pending deactivation, if it isn't already running. """ if not self._user_parter_running: - run_as_background_process( - "user_parter_loop", self.server_name, self._user_parter_loop + self.hs.run_as_background_process( + "user_parter_loop", self._user_parter_loop ) async def _user_parter_loop(self) -> None: diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index d47e3fd263..79dd3e8416 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -24,9 +24,6 @@ from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions -from synapse.metrics.background_process_metrics import ( - run_as_background_process, -) from synapse.replication.http.delayed_events import ( ReplicationAddedDelayedEventRestServlet, ) @@ -58,6 +55,7 @@ logger = logging.getLogger(__name__) class DelayedEventsHandler: def __init__(self, hs: "HomeServer"): + self.hs = hs self.server_name = hs.hostname self._store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() @@ -94,7 +92,10 @@ class DelayedEventsHandler: hs.get_notifier().add_replication_callback(self.notify_new_event) # Kick off again (without blocking) to catch any missed notifications # that may have fired before the callback was added. - self._clock.call_later(0, self.notify_new_event) + self._clock.call_later( + 0, + self.notify_new_event, + ) # Delayed events that are already marked as processed on startup might not have been # sent properly on the last run of the server, so unmark them to send them again. @@ -112,15 +113,14 @@ class DelayedEventsHandler: self._schedule_next_at(next_send_ts) # Can send the events in background after having awaited on marking them as processed - run_as_background_process( + self.hs.run_as_background_process( "_send_events", - self.server_name, self._send_events, events, ) - self._initialized_from_db = run_as_background_process( - "_schedule_db_events", self.server_name, _schedule_db_events + self._initialized_from_db = self.hs.run_as_background_process( + "_schedule_db_events", _schedule_db_events ) else: self._repl_client = ReplicationAddedDelayedEventRestServlet.make_client(hs) @@ -145,9 +145,7 @@ class DelayedEventsHandler: finally: self._event_processing = False - run_as_background_process( - "delayed_events.notify_new_event", self.server_name, process - ) + self.hs.run_as_background_process("delayed_events.notify_new_event", process) async def _unsafe_process_new_event(self) -> None: # We purposefully fetch the current max room stream ordering before @@ -542,9 +540,8 @@ class DelayedEventsHandler: if self._next_delayed_event_call is None: self._next_delayed_event_call = self._clock.call_later( delay_sec, - run_as_background_process, + self.hs.run_as_background_process, "_send_on_timeout", - self.server_name, self._send_on_timeout, ) else: diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 9509ac422e..c6024597b7 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -47,7 +47,6 @@ from synapse.api.errors import ( ) from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.metrics.background_process_metrics import ( - run_as_background_process, wrap_as_background_process, ) from synapse.replication.http.devices import ( @@ -125,7 +124,7 @@ class DeviceHandler: def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname # nb must be called this for @measure_func self.clock = hs.get_clock() # nb must be called this for @measure_func - self.hs = hs + self.hs = hs # nb must be called this for @wrap_as_background_process self.store = cast("GenericWorkerStore", hs.get_datastores().main) self.notifier = hs.get_notifier() self.state = hs.get_state_handler() @@ -191,10 +190,9 @@ class DeviceHandler: and self._delete_stale_devices_after is not None ): self.clock.looping_call( - run_as_background_process, + self.hs.run_as_background_process, DELETE_STALE_DEVICES_INTERVAL_MS, desc="delete_stale_devices", - server_name=self.server_name, func=self._delete_stale_devices, ) @@ -963,10 +961,9 @@ class DeviceWriterHandler(DeviceHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) + self.server_name = hs.hostname # nb must be called this for @measure_func + self.hs = hs # nb must be called this for @wrap_as_background_process - self.server_name = ( - hs.hostname - ) # nb must be called this for @measure_func and @wrap_as_background_process # We only need to poke the federation sender explicitly if its on the # same instance. Other federation sender instances will get notified by # `synapse.app.generic_worker.FederationSenderHandler` when it sees it @@ -1444,7 +1441,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): def __init__(self, hs: "HomeServer", device_handler: DeviceWriterHandler): super().__init__(hs) - self.server_name = hs.hostname + self.hs = hs self.federation = hs.get_federation_client() self.server_name = hs.hostname # nb must be called this for @measure_func self.clock = hs.get_clock() # nb must be called this for @measure_func @@ -1468,6 +1465,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache( cache_name="device_update_edu", server_name=self.server_name, + hs=self.hs, clock=self.clock, max_len=10000, expiry_ms=30 * 60 * 1000, @@ -1477,9 +1475,8 @@ class DeviceListUpdater(DeviceListWorkerUpdater): # Attempt to resync out of sync device lists every 30s. self._resync_retry_lock = Lock() self.clock.looping_call( - run_as_background_process, + self.hs.run_as_background_process, 30 * 1000, - server_name=self.server_name, func=self._maybe_retry_device_resync, desc="_maybe_retry_device_resync", ) @@ -1599,9 +1596,8 @@ class DeviceListUpdater(DeviceListWorkerUpdater): if resync: # We mark as stale up front in case we get restarted. await self.store.mark_remote_users_device_caches_as_stale([user_id]) - run_as_background_process( + self.hs.run_as_background_process( "_maybe_retry_device_resync", - self.server_name, self.multi_user_device_resync, [user_id], False, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 41fb3076c3..adc20f4ad0 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -72,7 +72,6 @@ from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import nested_logging_context from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace from synapse.metrics import SERVER_NAME_LABEL -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import NOT_SPAM from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.invite_rule import InviteRule @@ -188,9 +187,8 @@ class FederationHandler: # any partial-state-resync operations which were in flight when we # were shut down. if not hs.config.worker.worker_app: - run_as_background_process( + self.hs.run_as_background_process( "resume_sync_partial_state_room", - self.server_name, self._resume_partial_state_room_sync, ) @@ -318,9 +316,8 @@ class FederationHandler: logger.debug( "_maybe_backfill_inner: all backfill points are *after* current depth. Trying again with later backfill points." ) - run_as_background_process( + self.hs.run_as_background_process( "_maybe_backfill_inner_anyway_with_max_depth", - self.server_name, self.maybe_backfill, room_id=room_id, # We use `MAX_DEPTH` so that we find all backfill points next @@ -802,9 +799,8 @@ class FederationHandler: # lots of requests for missing prev_events which we do actually # have. Hence we fire off the background task, but don't wait for it. - run_as_background_process( + self.hs.run_as_background_process( "handle_queued_pdus", - self.server_name, self._handle_queued_pdus, room_queue, ) @@ -1877,9 +1873,8 @@ class FederationHandler: room_id=room_id, ) - run_as_background_process( + self.hs.run_as_background_process( desc="sync_partial_state_room", - server_name=self.server_name, func=_sync_partial_state_room_wrapper, ) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 59886f04c4..d6390b79c7 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -81,7 +81,6 @@ from synapse.logging.opentracing import ( trace, ) from synapse.metrics import SERVER_NAME_LABEL -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.federation import ( ReplicationFederationSendEventsRestServlet, ) @@ -153,6 +152,7 @@ class FederationEventHandler: def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname + self.hs = hs self._clock = hs.get_clock() self._store = hs.get_datastores().main self._state_store = hs.get_datastores().state @@ -175,6 +175,7 @@ class FederationEventHandler: ) self._notifier = hs.get_notifier() + self._server_name = hs.hostname self._is_mine_id = hs.is_mine_id self._is_mine_server_name = hs.is_mine_server_name self._instance_name = hs.get_instance_name() @@ -974,9 +975,8 @@ class FederationEventHandler: # Process previously failed backfill events in the background to not waste # time on something that is likely to fail again. if len(events_with_failed_pull_attempts) > 0: - run_as_background_process( + self.hs.run_as_background_process( "_process_new_pulled_events_with_failed_pull_attempts", - self.server_name, _process_new_pulled_events, events_with_failed_pull_attempts, ) @@ -1568,9 +1568,8 @@ class FederationEventHandler: resync = True if resync: - run_as_background_process( + self.hs.run_as_background_process( "resync_device_due_to_pdu", - self.server_name, self._resync_device, event.sender, ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4ff8b3704b..e874b60000 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -67,7 +67,6 @@ from synapse.handlers.directory import DirectoryHandler from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.logging import opentracing from synapse.logging.context import make_deferred_yieldable, run_in_background -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_events import ReplicationSendEventsRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( @@ -99,6 +98,7 @@ class MessageHandler: def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname + self.hs = hs self.auth = hs.get_auth() self.clock = hs.get_clock() self.state = hs.get_state_handler() @@ -113,8 +113,8 @@ class MessageHandler: self._scheduled_expiry: Optional[IDelayedCall] = None if not hs.config.worker.worker_app: - run_as_background_process( - "_schedule_next_expiry", self.server_name, self._schedule_next_expiry + self.hs.run_as_background_process( + "_schedule_next_expiry", self._schedule_next_expiry ) async def get_room_data( @@ -444,9 +444,8 @@ class MessageHandler: self._scheduled_expiry = self.clock.call_later( delay, - run_as_background_process, + self.hs.run_as_background_process, "_expire_event", - self.server_name, self._expire_event, event_id, ) @@ -548,9 +547,8 @@ class EventCreationHandler: and self.config.server.cleanup_extremities_with_dummy_events ): self.clock.looping_call( - lambda: run_as_background_process( + lambda: self.hs.run_as_background_process( "send_dummy_events_to_fill_extremities", - self.server_name, self._send_dummy_events_to_fill_extremities, ), 5 * 60 * 1000, @@ -570,6 +568,7 @@ class EventCreationHandler: self._external_cache_joined_hosts_updates = ExpiringCache( cache_name="_external_cache_joined_hosts_updates", server_name=self.server_name, + hs=self.hs, clock=self.clock, expiry_ms=30 * 60 * 1000, ) @@ -2113,9 +2112,8 @@ class EventCreationHandler: if event.type == EventTypes.Message: # We don't want to block sending messages on any presence code. This # matters as sometimes presence code can take a while. - run_as_background_process( + self.hs.run_as_background_process( "bump_presence_active_time", - self.server_name, self._bump_active_time, requester.user, requester.device_id, diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index df1a7e714c..02a67581e7 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -29,7 +29,6 @@ from synapse.api.filtering import Filter from synapse.events.utils import SerializeEventConfig from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.logging.opentracing import trace -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.admin._base import assert_user_is_admin from synapse.streams.config import PaginationConfig from synapse.types import ( @@ -116,10 +115,9 @@ class PaginationHandler: logger.info("Setting up purge job with config: %s", job) self.clock.looping_call( - run_as_background_process, + self.hs.run_as_background_process, job.interval, "purge_history_for_rooms_in_range", - self.server_name, self.purge_history_for_rooms_in_range, job.shortest_max_lifetime, job.longest_max_lifetime, @@ -244,9 +242,8 @@ class PaginationHandler: # We want to purge everything, including local events, and to run the purge in # the background so that it's not blocking any other operation apart from # other purges in the same room. - run_as_background_process( + self.hs.run_as_background_process( PURGE_HISTORY_ACTION_NAME, - self.server_name, self.purge_history, room_id, token, @@ -604,9 +601,8 @@ class PaginationHandler: # Otherwise, we can backfill in the background for eventual # consistency's sake but we don't need to block the client waiting # for a costly federation call and processing. - run_as_background_process( + self.hs.run_as_background_process( "maybe_backfill_in_the_background", - self.server_name, self.hs.get_federation_handler().maybe_backfill, room_id, curr_topo, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 4d246fadbd..1610683066 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -107,7 +107,6 @@ from synapse.events.presence_router import PresenceRouter from synapse.logging.context import run_in_background from synapse.metrics import SERVER_NAME_LABEL, LaterGauge from synapse.metrics.background_process_metrics import ( - run_as_background_process, wrap_as_background_process, ) from synapse.replication.http.presence import ( @@ -537,19 +536,15 @@ class WorkerPresenceHandler(BasePresenceHandler): self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs) self._set_state_client = ReplicationPresenceSetState.make_client(hs) - self._send_stop_syncing_loop = self.clock.looping_call( - self.send_stop_syncing, UPDATE_SYNCING_USERS_MS - ) - - hs.get_clock().add_system_event_trigger( - "before", - "shutdown", - run_as_background_process, - "generic_presence.on_shutdown", - self.server_name, - self._on_shutdown, + self.clock.looping_call(self.send_stop_syncing, UPDATE_SYNCING_USERS_MS) + + hs.register_async_shutdown_handler( + phase="before", + eventType="shutdown", + shutdown_func=self._on_shutdown, ) + @wrap_as_background_process("WorkerPresenceHandler._on_shutdown") async def _on_shutdown(self) -> None: if self._track_presence: self.hs.get_replication_command_handler().send_command( @@ -779,9 +774,7 @@ class WorkerPresenceHandler(BasePresenceHandler): class PresenceHandler(BasePresenceHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.server_name = ( - hs.hostname - ) # nb must be called this for @wrap_as_background_process + self.server_name = hs.hostname self.wheel_timer: WheelTimer[str] = WheelTimer() self.notifier = hs.get_notifier() @@ -842,13 +835,10 @@ class PresenceHandler(BasePresenceHandler): # have not yet been persisted self.unpersisted_users_changes: Set[str] = set() - hs.get_clock().add_system_event_trigger( - "before", - "shutdown", - run_as_background_process, - "presence.on_shutdown", - self.server_name, - self._on_shutdown, + hs.register_async_shutdown_handler( + phase="before", + eventType="shutdown", + shutdown_func=self._on_shutdown, ) # Keeps track of the number of *ongoing* syncs on this process. While @@ -881,7 +871,10 @@ class PresenceHandler(BasePresenceHandler): # The initial delay is to allow disconnected clients a chance to # reconnect before we treat them as offline. self.clock.call_later( - 30, self.clock.looping_call, self._handle_timeouts, 5000 + 30, + self.clock.looping_call, + self._handle_timeouts, + 5000, ) # Presence information is persisted, whether or not it is being tracked @@ -908,6 +901,7 @@ class PresenceHandler(BasePresenceHandler): self._event_pos = self.store.get_room_max_stream_ordering() self._event_processing = False + @wrap_as_background_process("PresenceHandler._on_shutdown") async def _on_shutdown(self) -> None: """Gets called when shutting down. This lets us persist any updates that we haven't yet persisted, e.g. updates that only changes some internal @@ -1539,8 +1533,8 @@ class PresenceHandler(BasePresenceHandler): finally: self._event_processing = False - run_as_background_process( - "presence.notify_new_event", self.server_name, _process_presence + self.hs.run_as_background_process( + "presence.notify_new_event", _process_presence ) async def _unsafe_process(self) -> None: diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index dbff28e7fb..9dda89d85b 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -56,8 +56,8 @@ class ProfileHandler: def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname # nb must be called this for @cached + self.clock = hs.get_clock() # nb must be called this for @cached self.store = hs.get_datastores().main - self.clock = hs.get_clock() self.hs = hs self.federation = hs.get_federation_client() diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 5761a7f70b..c3ff0cfaf8 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -23,7 +23,14 @@ """Contains functions for registering clients.""" import logging -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, TypedDict +from typing import ( + TYPE_CHECKING, + Iterable, + List, + Optional, + Tuple, + TypedDict, +) from prometheus_client import Counter diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 623823acb0..2ab9b70f8c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -50,7 +50,6 @@ from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.logging import opentracing from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.push import ReplicationCopyPusherRestServlet from synapse.storage.databases.main.state_deltas import StateDelta from synapse.storage.invite_rule import InviteRule @@ -2190,7 +2189,10 @@ class RoomForgetterHandler(StateDeltasHandler): self._notifier.add_replication_callback(self.notify_new_event) # We kick this off to pick up outstanding work from before the last restart. - self._clock.call_later(0, self.notify_new_event) + self._clock.call_later( + 0, + self.notify_new_event, + ) def notify_new_event(self) -> None: """Called when there may be more deltas to process""" @@ -2205,9 +2207,7 @@ class RoomForgetterHandler(StateDeltasHandler): finally: self._is_processing = False - run_as_background_process( - "room_forgetter.notify_new_event", self.server_name, process - ) + self._hs.run_as_background_process("room_forgetter.notify_new_event", process) async def _unsafe_process(self) -> None: # If self.pos is None then means we haven't fetched it from DB diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index eec420cbb1..735cfa0a0f 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -224,7 +224,7 @@ class SsoHandler: ) # a lock on the mappings - self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock()) + self._mapping_lock = Linearizer(clock=hs.get_clock(), name="sso_user_mapping") # a map from session id to session data self._username_mapping_sessions: Dict[str, UsernameMappingSession] = {} diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index a2602ea818..5b4a2cc62d 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -33,7 +33,6 @@ from typing import ( from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main.state_deltas import StateDelta from synapse.types import JsonDict from synapse.util.events import get_plain_text_topic_from_event_content @@ -75,7 +74,10 @@ class StatsHandler: # We kick this off so that we don't have to wait for a change before # we start populating stats - self.clock.call_later(0, self.notify_new_event) + self.clock.call_later( + 0, + self.notify_new_event, + ) def notify_new_event(self) -> None: """Called when there may be more deltas to process""" @@ -90,7 +92,7 @@ class StatsHandler: finally: self._is_processing = False - run_as_background_process("stats.notify_new_event", self.server_name, process) + self.hs.run_as_background_process("stats.notify_new_event", process) async def _unsafe_process(self) -> None: # If self.pos is None then means we haven't fetched it from DB diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c0341c5654..6f0522d5bb 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -323,6 +323,7 @@ class SyncHandler: ] = ExpiringCache( cache_name="lazy_loaded_members_cache", server_name=self.server_name, + hs=hs, clock=self.clock, max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, @@ -982,6 +983,7 @@ class SyncHandler: logger.debug("creating LruCache for %r", cache_key) cache = LruCache( max_size=LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE, + clock=self.clock, server_name=self.server_name, ) self.lazy_loaded_members_cache[cache_key] = cache diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 6a7b36ea0c..77c5b747c3 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -28,7 +28,6 @@ from synapse.api.constants import EduTypes from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.appservice import ApplicationService from synapse.metrics.background_process_metrics import ( - run_as_background_process, wrap_as_background_process, ) from synapse.replication.tcp.streams import TypingStream @@ -78,11 +77,10 @@ class FollowerTypingHandler: """ def __init__(self, hs: "HomeServer"): + self.hs = hs # nb must be called this for @wrap_as_background_process self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() - self.server_name = ( - hs.hostname - ) # nb must be called this for @wrap_as_background_process + self.server_name = hs.hostname self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id self.is_mine_server_name = hs.is_mine_server_name @@ -144,9 +142,8 @@ class FollowerTypingHandler: if self.federation and self.is_mine_id(member.user_id): last_fed_poke = self._member_last_federation_poke.get(member, None) if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: - run_as_background_process( + self.hs.run_as_background_process( "typing._push_remote", - self.server_name, self._push_remote, member=member, typing=True, @@ -220,9 +217,8 @@ class FollowerTypingHandler: self._rooms_updated.add(row.room_id) if self.federation: - run_as_background_process( + self.hs.run_as_background_process( "_send_changes_in_typing_to_remotes", - self.server_name, self._send_changes_in_typing_to_remotes, row.room_id, prev_typing, @@ -384,9 +380,8 @@ class TypingWriterHandler(FollowerTypingHandler): def _push_update(self, member: RoomMember, typing: bool) -> None: if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. - run_as_background_process( + self.hs.run_as_background_process( "typing._push_remote", - self.server_name, self._push_remote, member, typing, diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 130099a239..28961f5925 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -36,7 +36,6 @@ from synapse.api.constants import ( from synapse.api.errors import Codes, SynapseError from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler from synapse.metrics import SERVER_NAME_LABEL -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main.state_deltas import StateDelta from synapse.storage.databases.main.user_directory import SearchResult from synapse.storage.roommember import ProfileInfo @@ -137,11 +136,15 @@ class UserDirectoryHandler(StateDeltasHandler): # We kick this off so that we don't have to wait for a change before # we start populating the user directory - self.clock.call_later(0, self.notify_new_event) + self.clock.call_later( + 0, + self.notify_new_event, + ) # Kick off the profile refresh process on startup self._refresh_remote_profiles_call_later = self.clock.call_later( - 10, self.kick_off_remote_profile_refresh_process + 10, + self.kick_off_remote_profile_refresh_process, ) async def search_users( @@ -193,9 +196,7 @@ class UserDirectoryHandler(StateDeltasHandler): self._is_processing = False self._is_processing = True - run_as_background_process( - "user_directory.notify_new_event", self.server_name, process - ) + self._hs.run_as_background_process("user_directory.notify_new_event", process) async def handle_local_profile_change( self, user_id: str, profile: ProfileInfo @@ -609,8 +610,8 @@ class UserDirectoryHandler(StateDeltasHandler): self._is_refreshing_remote_profiles = False self._is_refreshing_remote_profiles = True - run_as_background_process( - "user_directory.refresh_remote_profiles", self.server_name, process + self._hs.run_as_background_process( + "user_directory.refresh_remote_profiles", process ) async def _unsafe_refresh_remote_profiles(self) -> None: @@ -655,8 +656,9 @@ class UserDirectoryHandler(StateDeltasHandler): if not users: return _, _, next_try_at_ts = users[0] + delay = ((next_try_at_ts - self.clock.time_msec()) // 1000) + 2 self._refresh_remote_profiles_call_later = self.clock.call_later( - ((next_try_at_ts - self.clock.time_msec()) // 1000) + 2, + delay, self.kick_off_remote_profile_refresh_process, ) @@ -692,9 +694,8 @@ class UserDirectoryHandler(StateDeltasHandler): self._is_refreshing_remote_profiles_for_servers.remove(server_name) self._is_refreshing_remote_profiles_for_servers.add(server_name) - run_as_background_process( + self._hs.run_as_background_process( "user_directory.refresh_remote_profiles_for_remote_server", - self.server_name, process, ) diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py index 0b375790dd..ca1e2b166c 100644 --- a/synapse/handlers/worker_lock.py +++ b/synapse/handlers/worker_lock.py @@ -37,13 +37,13 @@ from weakref import WeakSet import attr from twisted.internet import defer -from twisted.internet.interfaces import IReactorTime from synapse.logging.context import PreserveLoggingContext from synapse.logging.opentracing import start_active_span from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage.databases.main.lock import Lock, LockStore from synapse.util.async_helpers import timeout_deferred +from synapse.util.clock import Clock from synapse.util.constants import ONE_MINUTE_SECONDS if TYPE_CHECKING: @@ -66,10 +66,8 @@ class WorkerLocksHandler: """ def __init__(self, hs: "HomeServer") -> None: - self.server_name = ( - hs.hostname - ) # nb must be called this for @wrap_as_background_process - self._reactor = hs.get_reactor() + self.hs = hs # nb must be called this for @wrap_as_background_process + self._clock = hs.get_clock() self._store = hs.get_datastores().main self._clock = hs.get_clock() self._notifier = hs.get_notifier() @@ -98,7 +96,7 @@ class WorkerLocksHandler: """ lock = WaitingLock( - reactor=self._reactor, + clock=self._clock, store=self._store, handler=self, lock_name=lock_name, @@ -129,7 +127,7 @@ class WorkerLocksHandler: """ lock = WaitingLock( - reactor=self._reactor, + clock=self._clock, store=self._store, handler=self, lock_name=lock_name, @@ -160,7 +158,7 @@ class WorkerLocksHandler: lock = WaitingMultiLock( lock_names=lock_names, write=write, - reactor=self._reactor, + clock=self._clock, store=self._store, handler=self, ) @@ -197,7 +195,11 @@ class WorkerLocksHandler: if not deferred.called: deferred.callback(None) - self._clock.call_later(0, _wake_all_locks, locks) + self._clock.call_later( + 0, + _wake_all_locks, + locks, + ) @wrap_as_background_process("_cleanup_locks") async def _cleanup_locks(self) -> None: @@ -207,7 +209,7 @@ class WorkerLocksHandler: @attr.s(auto_attribs=True, eq=False) class WaitingLock: - reactor: IReactorTime + clock: Clock store: LockStore handler: WorkerLocksHandler lock_name: str @@ -246,10 +248,11 @@ class WaitingLock: # periodically wake up in case the lock was released but we # weren't notified. with PreserveLoggingContext(): + timeout = self._get_next_retry_interval() await timeout_deferred( deferred=self.deferred, - timeout=self._get_next_retry_interval(), - reactor=self.reactor, + timeout=timeout, + clock=self.clock, ) except Exception: pass @@ -290,7 +293,7 @@ class WaitingMultiLock: write: bool - reactor: IReactorTime + clock: Clock store: LockStore handler: WorkerLocksHandler @@ -323,10 +326,11 @@ class WaitingMultiLock: # periodically wake up in case the lock was released but we # weren't notified. with PreserveLoggingContext(): + timeout = self._get_next_retry_interval() await timeout_deferred( deferred=self.deferred, - timeout=self._get_next_retry_interval(), - reactor=self.reactor, + timeout=timeout, + clock=self.clock, ) except Exception: pass diff --git a/synapse/http/client.py b/synapse/http/client.py index bbb0efe8b5..370cdc3568 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -54,7 +54,6 @@ from twisted.internet.interfaces import ( IOpenSSLContextFactory, IReactorCore, IReactorPluggableNameResolver, - IReactorTime, IResolutionReceiver, ITCPTransport, ) @@ -88,6 +87,7 @@ from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.metrics import SERVER_NAME_LABEL from synapse.types import ISynapseReactor, StrSequence from synapse.util.async_helpers import timeout_deferred +from synapse.util.clock import Clock from synapse.util.json import json_decoder if TYPE_CHECKING: @@ -165,16 +165,17 @@ def _is_ip_blocked( _EPSILON = 0.00000001 -def _make_scheduler( - reactor: IReactorTime, -) -> Callable[[Callable[[], object]], IDelayedCall]: +def _make_scheduler(clock: Clock) -> Callable[[Callable[[], object]], IDelayedCall]: """Makes a schedular suitable for a Cooperator using the given reactor. (This is effectively just a copy from `twisted.internet.task`) """ def _scheduler(x: Callable[[], object]) -> IDelayedCall: - return reactor.callLater(_EPSILON, x) + return clock.call_later( + _EPSILON, + x, + ) return _scheduler @@ -367,7 +368,7 @@ class BaseHttpClient: # We use this for our body producers to ensure that they use the correct # reactor. - self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor())) + self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_clock())) async def request( self, @@ -436,9 +437,9 @@ class BaseHttpClient: # we use our own timeout mechanism rather than treq's as a workaround # for https://twistedmatrix.com/trac/ticket/9534. request_deferred = timeout_deferred( - request_deferred, - 60, - self.hs.get_reactor(), + deferred=request_deferred, + timeout=60, + clock=self.hs.get_clock(), ) # turn timeouts into RequestTimedOutErrors @@ -763,7 +764,11 @@ class BaseHttpClient: d = read_body_with_max_size(response, output_stream, max_size) # Ensure that the body is not read forever. - d = timeout_deferred(d, 30, self.hs.get_reactor()) + d = timeout_deferred( + deferred=d, + timeout=30, + clock=self.hs.get_clock(), + ) length = await make_deferred_yieldable(d) except BodyExceededMaxSize: @@ -957,9 +962,9 @@ class ReplicationClient(BaseHttpClient): # for https://twistedmatrix.com/trac/ticket/9534. # (Updated url https://github.com/twisted/twisted/issues/9534) request_deferred = timeout_deferred( - request_deferred, - 60, - self.hs.get_reactor(), + deferred=request_deferred, + timeout=60, + clock=self.hs.get_clock(), ) # turn timeouts into RequestTimedOutErrors diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 98826c9171..9d87514be0 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -67,6 +67,9 @@ class MatrixFederationAgent: Args: reactor: twisted reactor to use for underlying requests + clock: Internal `HomeServer` clock used to track delayed and looping calls. + Should be obtained from `hs.get_clock()`. + tls_client_options_factory: factory to use for fetching client tls options, or none to disable TLS. @@ -97,6 +100,7 @@ class MatrixFederationAgent: *, server_name: str, reactor: ISynapseReactor, + clock: Clock, tls_client_options_factory: Optional[FederationPolicyForHTTPS], user_agent: bytes, ip_allowlist: Optional[IPSet], @@ -109,6 +113,7 @@ class MatrixFederationAgent: Args: server_name: Our homeserver name (used to label metrics) (`hs.hostname`). reactor + clock: Should be the `hs` clock from `hs.get_clock()` tls_client_options_factory user_agent ip_allowlist @@ -124,7 +129,6 @@ class MatrixFederationAgent: # addresses, to prevent DNS rebinding. reactor = BlocklistingReactorWrapper(reactor, ip_allowlist, ip_blocklist) - self._clock = Clock(reactor, server_name=server_name) self._pool = HTTPConnectionPool(reactor) self._pool.retryAutomatically = False self._pool.maxPersistentPerHost = 5 @@ -147,6 +151,7 @@ class MatrixFederationAgent: _well_known_resolver = WellKnownResolver( server_name=server_name, reactor=reactor, + clock=clock, agent=BlocklistingAgentWrapper( ProxyAgent( reactor=reactor, diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index 97bba8231a..2f52abcc03 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -90,6 +90,7 @@ class WellKnownResolver: self, server_name: str, reactor: ISynapseThreadlessReactor, + clock: Clock, agent: IAgent, user_agent: bytes, well_known_cache: Optional[TTLCache[bytes, Optional[bytes]]] = None, @@ -99,6 +100,7 @@ class WellKnownResolver: Args: server_name: Our homeserver name (used to label metrics) (`hs.hostname`). reactor + clock: Should be the `hs` clock from `hs.get_clock()` agent user_agent well_known_cache @@ -107,7 +109,7 @@ class WellKnownResolver: self.server_name = server_name self._reactor = reactor - self._clock = Clock(reactor, server_name=server_name) + self._clock = clock if well_known_cache is None: well_known_cache = TTLCache( diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index c264bae6e5..4d72c72d01 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -90,6 +90,7 @@ from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.metrics import SERVER_NAME_LABEL from synapse.types import JsonDict from synapse.util.async_helpers import AwakenableSleeper, Linearizer, timeout_deferred +from synapse.util.clock import Clock from synapse.util.json import json_decoder from synapse.util.metrics import Measure from synapse.util.stringutils import parse_and_validate_server_name @@ -270,6 +271,7 @@ class LegacyJsonSendParser(_BaseJsonParser[Tuple[int, JsonDict]]): async def _handle_response( + clock: Clock, reactor: IReactorTime, timeout_sec: float, request: MatrixFederationRequest, @@ -299,7 +301,11 @@ async def _handle_response( check_content_type_is(response.headers, parser.CONTENT_TYPE) d = read_body_with_max_size(response, parser, max_response_size) - d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) + d = timeout_deferred( + deferred=d, + timeout=timeout_sec, + clock=clock, + ) length = await make_deferred_yieldable(d) @@ -411,6 +417,7 @@ class MatrixFederationHttpClient: self.server_name = hs.hostname self.reactor = hs.get_reactor() + self.clock = hs.get_clock() user_agent = hs.version_string if hs.config.server.user_agent_suffix: @@ -424,6 +431,7 @@ class MatrixFederationHttpClient: federation_agent: IAgent = MatrixFederationAgent( server_name=self.server_name, reactor=self.reactor, + clock=self.clock, tls_client_options_factory=tls_client_options_factory, user_agent=user_agent.encode("ascii"), ip_allowlist=hs.config.server.federation_ip_range_allowlist, @@ -457,7 +465,6 @@ class MatrixFederationHttpClient: ip_blocklist=hs.config.server.federation_ip_range_blocklist, ) - self.clock = hs.get_clock() self._store = hs.get_datastores().main self.version_string_bytes = hs.version_string.encode("ascii") self.default_timeout_seconds = hs.config.federation.client_timeout_ms / 1000 @@ -470,9 +477,9 @@ class MatrixFederationHttpClient: self.max_long_retries = hs.config.federation.max_long_retries self.max_short_retries = hs.config.federation.max_short_retries - self._cooperator = Cooperator(scheduler=_make_scheduler(self.reactor)) + self._cooperator = Cooperator(scheduler=_make_scheduler(self.clock)) - self._sleeper = AwakenableSleeper(self.reactor) + self._sleeper = AwakenableSleeper(self.clock) self._simple_http_client = SimpleHttpClient( hs, @@ -484,6 +491,10 @@ class MatrixFederationHttpClient: self.remote_download_linearizer = Linearizer( name="remote_download_linearizer", max_count=6, clock=self.clock ) + self._is_shutdown = False + + def shutdown(self) -> None: + self._is_shutdown = True def wake_destination(self, destination: str) -> None: """Called when the remote server may have come back online.""" @@ -629,6 +640,7 @@ class MatrixFederationHttpClient: limiter = await synapse.util.retryutils.get_retry_limiter( destination=request.destination, our_server_name=self.server_name, + hs=self.hs, clock=self.clock, store=self._store, backoff_on_404=backoff_on_404, @@ -675,7 +687,7 @@ class MatrixFederationHttpClient: (b"", b"", path_bytes, None, query_bytes, b"") ) - while True: + while not self._is_shutdown: try: json = request.get_json() if json: @@ -733,9 +745,9 @@ class MatrixFederationHttpClient: bodyProducer=producer, ) request_deferred = timeout_deferred( - request_deferred, + deferred=request_deferred, timeout=_sec_timeout, - reactor=self.reactor, + clock=self.clock, ) response = await make_deferred_yieldable(request_deferred) @@ -793,7 +805,9 @@ class MatrixFederationHttpClient: # Update transactions table? d = treq.content(response) d = timeout_deferred( - d, timeout=_sec_timeout, reactor=self.reactor + deferred=d, + timeout=_sec_timeout, + clock=self.clock, ) try: @@ -862,6 +876,15 @@ class MatrixFederationHttpClient: delay_seconds, ) + if self._is_shutdown: + # Immediately fail sending the request instead of starting a + # potentially long sleep after the server has requested + # shutdown. + # This is the code path followed when the + # `federation_transaction_transmission_loop` has been + # cancelled. + raise + # Sleep for the calculated delay, or wake up immediately # if we get notified that the server is back up. await self._sleeper.sleep( @@ -1074,6 +1097,7 @@ class MatrixFederationHttpClient: parser = cast(ByteParser[T], JsonParser()) body = await _handle_response( + self.clock, self.reactor, _sec_timeout, request, @@ -1152,7 +1176,13 @@ class MatrixFederationHttpClient: _sec_timeout = self.default_timeout_seconds body = await _handle_response( - self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() + self.clock, + self.reactor, + _sec_timeout, + request, + response, + start_ms, + parser=JsonParser(), ) return body @@ -1358,6 +1388,7 @@ class MatrixFederationHttpClient: parser = cast(ByteParser[T], JsonParser()) body = await _handle_response( + self.clock, self.reactor, _sec_timeout, request, @@ -1431,7 +1462,13 @@ class MatrixFederationHttpClient: _sec_timeout = self.default_timeout_seconds body = await _handle_response( - self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() + self.clock, + self.reactor, + _sec_timeout, + request, + response, + start_ms, + parser=JsonParser(), ) return body diff --git a/synapse/http/proxy.py b/synapse/http/proxy.py index 9b044f3b0a..fa17432984 100644 --- a/synapse/http/proxy.py +++ b/synapse/http/proxy.py @@ -161,12 +161,12 @@ class ProxyResource(_AsyncResource): bodyProducer=QuieterFileBodyProducer(request.content), ) request_deferred = timeout_deferred( - request_deferred, + deferred=request_deferred, # This should be set longer than the timeout in `MatrixFederationHttpClient` # so that it has enough time to complete and pass us the data before we give # up. timeout=90, - reactor=self.reactor, + clock=self._clock, ) response = await make_deferred_yieldable(request_deferred) diff --git a/synapse/http/server.py b/synapse/http/server.py index ce9d5630df..d5af8758ac 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -420,7 +420,14 @@ class DirectServeJsonResource(_AsyncResource): """ if clock is None: - clock = Clock( + # Ideally we wouldn't ignore the linter error here and instead enforce a + # required `Clock` be passed into the `__init__` function. + # However, this would change the function signature which is currently being + # exported to the module api. Since we don't want to break that api, we have + # to settle with ignoring the linter error here. + # As of the time of writing this, all Synapse internal usages of + # `DirectServeJsonResource` pass in the existing homeserver clock instance. + clock = Clock( # type: ignore[multiple-internal-clocks] cast(ISynapseThreadlessReactor, reactor), server_name="synapse_module_running_from_unknown_server", ) @@ -608,7 +615,14 @@ class DirectServeHtmlResource(_AsyncResource): Only optional for the Module API. """ if clock is None: - clock = Clock( + # Ideally we wouldn't ignore the linter error here and instead enforce a + # required `Clock` be passed into the `__init__` function. + # However, this would change the function signature which is currently being + # exported to the module api. Since we don't want to break that api, we have + # to settle with ignoring the linter error here. + # As of the time of writing this, all Synapse internal usages of + # `DirectServeHtmlResource` pass in the existing homeserver clock instance. + clock = Clock( # type: ignore[multiple-internal-clocks] cast(ISynapseThreadlessReactor, reactor), server_name="synapse_module_running_from_unknown_server", ) diff --git a/synapse/http/site.py b/synapse/http/site.py index 2c0c301c03..f4f326cfde 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -22,7 +22,7 @@ import contextlib import logging import time from http import HTTPStatus -from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Generator, List, Optional, Tuple, Union import attr from zope.interface import implementer @@ -30,6 +30,7 @@ from zope.interface import implementer from twisted.internet.address import UNIXAddress from twisted.internet.defer import Deferred from twisted.internet.interfaces import IAddress +from twisted.internet.protocol import Protocol from twisted.python.failure import Failure from twisted.web.http import HTTPChannel from twisted.web.resource import IResource, Resource @@ -660,6 +661,70 @@ class _XForwardedForAddress: host: str +class SynapseProtocol(HTTPChannel): + """ + Synapse-specific twisted http Protocol. + + This is a small wrapper around the twisted HTTPChannel so we can track active + connections in order to close any outstanding connections on shutdown. + """ + + def __init__( + self, + site: "SynapseSite", + our_server_name: str, + max_request_body_size: int, + request_id_header: Optional[str], + request_class: type, + ): + super().__init__() + self.factory: SynapseSite = site + self.site = site + self.our_server_name = our_server_name + self.max_request_body_size = max_request_body_size + self.request_id_header = request_id_header + self.request_class = request_class + + def connectionMade(self) -> None: + """ + Called when a connection is made. + + This may be considered the initializer of the protocol, because + it is called when the connection is completed. + + Add the connection to the factory's connection list when it's established. + """ + super().connectionMade() + self.factory.addConnection(self) + + def connectionLost(self, reason: Failure) -> None: # type: ignore[override] + """ + Called when the connection is shut down. + + Clear any circular references here, and any external references to this + Protocol. The connection has been closed. In our case, we need to remove the + connection from the factory's connection list, when it's lost. + """ + super().connectionLost(reason) + self.factory.removeConnection(self) + + def requestFactory(self, http_channel: HTTPChannel, queued: bool) -> SynapseRequest: # type: ignore[override] + """ + A callable used to build `twisted.web.iweb.IRequest` objects. + + Use our own custom SynapseRequest type instead of the regular + twisted.web.server.Request. + """ + return self.request_class( + self, + self.factory, + our_server_name=self.our_server_name, + max_request_body_size=self.max_request_body_size, + queued=queued, + request_id_header=self.request_id_header, + ) + + class SynapseSite(ProxySite): """ Synapse-specific twisted http Site @@ -710,23 +775,44 @@ class SynapseSite(ProxySite): assert config.http_options is not None proxied = config.http_options.x_forwarded - request_class = XForwardedForRequest if proxied else SynapseRequest + self.request_class = XForwardedForRequest if proxied else SynapseRequest - request_id_header = config.http_options.request_id_header + self.request_id_header = config.http_options.request_id_header + self.max_request_body_size = max_request_body_size - def request_factory(channel: HTTPChannel, queued: bool) -> Request: - return request_class( - channel, - self, - our_server_name=self.server_name, - max_request_body_size=max_request_body_size, - queued=queued, - request_id_header=request_id_header, - ) - - self.requestFactory = request_factory # type: ignore self.access_logger = logging.getLogger(logger_name) self.server_version_string = server_version_string.encode("ascii") + self.connections: List[Protocol] = [] + + def buildProtocol(self, addr: IAddress) -> SynapseProtocol: + protocol = SynapseProtocol( + self, + self.server_name, + self.max_request_body_size, + self.request_id_header, + self.request_class, + ) + return protocol + + def addConnection(self, protocol: Protocol) -> None: + self.connections.append(protocol) + + def removeConnection(self, protocol: Protocol) -> None: + if protocol in self.connections: + self.connections.remove(protocol) + + def stopFactory(self) -> None: + super().stopFactory() + + # Shutdown any connections which are still active. + # These can be long lived HTTP connections which wouldn't normally be closed + # when calling `shutdown` on the respective `Port`. + # Closing the connections here is required for us to fully shutdown the + # `SynapseHomeServer` in order for it to be garbage collected. + for protocol in self.connections[:]: + if protocol.transport is not None: + protocol.transport.loseConnection() + self.connections.clear() def log(self, request: SynapseRequest) -> None: # type: ignore[override] pass diff --git a/synapse/media/_base.py b/synapse/media/_base.py index 15b28074fd..d3a9a66f5a 100644 --- a/synapse/media/_base.py +++ b/synapse/media/_base.py @@ -704,6 +704,7 @@ class ThreadedFileSender: def __init__(self, hs: "HomeServer") -> None: self.reactor = hs.get_reactor() + self.clock = hs.get_clock() self.thread_pool = hs.get_media_sender_thread_pool() self.file: Optional[BinaryIO] = None @@ -712,7 +713,7 @@ class ThreadedFileSender: # Signals if the thread should keep reading/sending data. Set means # continue, clear means pause. - self.wakeup_event = DeferredEvent(self.reactor) + self.wakeup_event = DeferredEvent(self.clock) # Signals if the thread should terminate, e.g. because the consumer has # gone away. diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 436d9b7e35..238dc6cb2f 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -67,7 +67,6 @@ from synapse.media.media_storage import ( from synapse.media.storage_provider import StorageProviderWrapper from synapse.media.thumbnailer import Thumbnailer, ThumbnailError from synapse.media.url_previewer import UrlPreviewer -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia from synapse.types import UserID from synapse.util.async_helpers import Linearizer @@ -187,16 +186,14 @@ class MediaRepository: self.media_repository_callbacks = hs.get_module_api_callbacks().media_repository def _start_update_recently_accessed(self) -> Deferred: - return run_as_background_process( + return self.hs.run_as_background_process( "update_recently_accessed_media", - self.server_name, self._update_recently_accessed, ) def _start_apply_media_retention_rules(self) -> Deferred: - return run_as_background_process( + return self.hs.run_as_background_process( "apply_media_retention_rules", - self.server_name, self._apply_media_retention_rules, ) diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py index 81204913f7..1a82cc46e3 100644 --- a/synapse/media/url_previewer.py +++ b/synapse/media/url_previewer.py @@ -44,7 +44,6 @@ from synapse.media._base import FileInfo, get_filename_from_headers from synapse.media.media_storage import MediaStorage, SHA256TransparentIOWriter from synapse.media.oembed import OEmbedProvider from synapse.media.preview_html import decode_body, parse_html_to_open_graph -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import JsonDict, UserID from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache @@ -167,6 +166,7 @@ class UrlPreviewer: media_storage: MediaStorage, ): self.clock = hs.get_clock() + self.hs = hs self.filepaths = media_repo.filepaths self.max_spider_size = hs.config.media.max_spider_size self.server_name = hs.hostname @@ -201,15 +201,14 @@ class UrlPreviewer: self._cache: ExpiringCache[str, ObservableDeferred] = ExpiringCache( cache_name="url_previews", server_name=self.server_name, + hs=self.hs, clock=self.clock, # don't spider URLs more often than once an hour expiry_ms=ONE_HOUR, ) if self._worker_run_media_background_jobs: - self._cleaner_loop = self.clock.looping_call( - self._start_expire_url_cache_data, 10 * 1000 - ) + self.clock.looping_call(self._start_expire_url_cache_data, 10 * 1000) async def preview(self, url: str, user: UserID, ts: int) -> bytes: # the in-memory cache: @@ -739,8 +738,8 @@ class UrlPreviewer: return open_graph_result, oembed_response.author_name, expiration_ms def _start_expire_url_cache_data(self) -> Deferred: - return run_as_background_process( - "expire_url_cache_data", self.server_name, self._expire_url_cache_data + return self.hs.run_as_background_process( + "expire_url_cache_data", self._expire_url_cache_data ) async def _expire_url_cache_data(self) -> None: diff --git a/synapse/metrics/_gc.py b/synapse/metrics/_gc.py index e7783b05e6..1da871f18f 100644 --- a/synapse/metrics/_gc.py +++ b/synapse/metrics/_gc.py @@ -138,7 +138,9 @@ def install_gc_manager() -> None: gc_time.labels(i).observe(end - start) gc_unreachable.labels(i).set(unreachable) - gc_task = task.LoopingCall(_maybe_gc) + # We can ignore the lint here since this looping call does not hold a `HomeServer` + # reference so can be cleaned up by other means on shutdown. + gc_task = task.LoopingCall(_maybe_gc) # type: ignore[prefer-synapse-clock-looping-call] gc_task.start(0.1) diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 93345b0e9d..6dc2cbe132 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -66,6 +66,8 @@ if TYPE_CHECKING: # Old versions don't have `LiteralString` from typing_extensions import LiteralString + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -397,11 +399,11 @@ def run_as_background_process( P = ParamSpec("P") -class HasServerName(Protocol): - server_name: str +class HasHomeServer(Protocol): + hs: "HomeServer" """ - The homeserver name that this cache is associated with (used to label the metric) - (`hs.hostname`). + The homeserver that this cache is associated with (used to label the metric and + track backgroun processes for clean shutdown). """ @@ -431,27 +433,22 @@ def wrap_as_background_process( """ def wrapper( - func: Callable[Concatenate[HasServerName, P], Awaitable[Optional[R]]], + func: Callable[Concatenate[HasHomeServer, P], Awaitable[Optional[R]]], ) -> Callable[P, "defer.Deferred[Optional[R]]"]: @wraps(func) def wrapped_func( - self: HasServerName, *args: P.args, **kwargs: P.kwargs + self: HasHomeServer, *args: P.args, **kwargs: P.kwargs ) -> "defer.Deferred[Optional[R]]": - assert self.server_name is not None, ( - "The `server_name` attribute must be set on the object where `@wrap_as_background_process` decorator is used." + assert self.hs is not None, ( + "The `hs` attribute must be set on the object where `@wrap_as_background_process` decorator is used." ) - return run_as_background_process( + return self.hs.run_as_background_process( desc, - self.server_name, func, self, *args, - # type-ignore: mypy is confusing kwargs with the bg_start_span kwarg. - # Argument 4 to "run_as_background_process" has incompatible type - # "**P.kwargs"; expected "bool" - # See https://github.com/python/mypy/issues/8862 - **kwargs, # type: ignore[arg-type] + **kwargs, ) # There are some shenanigans here, because we're decorating a method but diff --git a/synapse/metrics/common_usage_metrics.py b/synapse/metrics/common_usage_metrics.py index cd1c3c8649..43e0913d27 100644 --- a/synapse/metrics/common_usage_metrics.py +++ b/synapse/metrics/common_usage_metrics.py @@ -23,7 +23,6 @@ from typing import TYPE_CHECKING import attr from synapse.metrics import SERVER_NAME_LABEL -from synapse.metrics.background_process_metrics import run_as_background_process if TYPE_CHECKING: from synapse.server import HomeServer @@ -52,6 +51,7 @@ class CommonUsageMetricsManager: self.server_name = hs.hostname self._store = hs.get_datastores().main self._clock = hs.get_clock() + self._hs = hs async def get_metrics(self) -> CommonUsageMetrics: """Get the CommonUsageMetrics object. If no collection has happened yet, do it @@ -64,16 +64,14 @@ class CommonUsageMetricsManager: async def setup(self) -> None: """Keep the gauges for common usage metrics up to date.""" - run_as_background_process( + self._hs.run_as_background_process( desc="common_usage_metrics_update_gauges", - server_name=self.server_name, func=self._update_gauges, ) self._clock.looping_call( - run_as_background_process, + self._hs.run_as_background_process, 5 * 60 * 1000, desc="common_usage_metrics_update_gauges", - server_name=self.server_name, func=self._update_gauges, ) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 7a419145e0..12a31dd2ab 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -275,7 +275,15 @@ def run_as_background_process( # function instead. stub_server_name = "synapse_module_running_from_unknown_server" - return _run_as_background_process( + # Ignore the linter error here. Since this is leveraging the + # `run_as_background_process` function directly and we don't want to break the + # module api, we need to keep the function signature the same. This means we don't + # have access to the running `HomeServer` and cannot track this background process + # for cleanup during shutdown. + # This is not an issue during runtime and is only potentially problematic if the + # application cares about being able to garbage collect `HomeServer` instances + # during runtime. + return _run_as_background_process( # type: ignore[untracked-background-process] desc, stub_server_name, func, @@ -1402,7 +1410,7 @@ class ModuleApi: if self._hs.config.worker.run_background_tasks or run_on_all_instances: self._clock.looping_call( - self.run_as_background_process, + self._hs.run_as_background_process, msec, desc, lambda: maybe_awaitable(f(*args, **kwargs)), @@ -1460,7 +1468,7 @@ class ModuleApi: return self._clock.call_later( # convert ms to seconds as needed by call_later. msec * 0.001, - self.run_as_background_process, + self._hs.run_as_background_process, desc, lambda: maybe_awaitable(f(*args, **kwargs)), ) @@ -1701,8 +1709,8 @@ class ModuleApi: Note that the returned Deferred does not follow the synapse logcontext rules. """ - return _run_as_background_process( - desc, self.server_name, func, *args, bg_start_span=bg_start_span, **kwargs + return self._hs.run_as_background_process( + desc, func, *args, bg_start_span=bg_start_span, **kwargs ) async def defer_to_thread( diff --git a/synapse/notifier.py b/synapse/notifier.py index e684df4866..9169f50c4d 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -676,9 +676,16 @@ class Notifier: # is a new token. listener = user_stream.new_listener(prev_token) listener = timeout_deferred( - listener, - (end_time - now) / 1000.0, - self.hs.get_reactor(), + deferred=listener, + timeout=(end_time - now) / 1000.0, + # We don't track these calls since they are constantly being + # overridden by new calls to /sync and they don't hold the + # `HomeServer` in memory on shutdown. It is safe to let them + # timeout of their own accord after shutting down since it + # won't delay shutdown and there won't be any adverse + # behaviour. + cancel_on_shutdown=False, + clock=self.hs.get_clock(), ) log_kv( diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 09ca14584a..1484bc8fc0 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -25,7 +25,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.interfaces import IDelayedCall -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import Pusher, PusherConfig, PusherConfigException, ThrottleParams from synapse.push.mailer import Mailer from synapse.push.push_types import EmailReason @@ -118,7 +117,7 @@ class EmailPusher(Pusher): if self._is_processing: return - run_as_background_process("emailpush.process", self.server_name, self._process) + self.hs.run_as_background_process("emailpush.process", self._process) def _pause_processing(self) -> None: """Used by tests to temporarily pause processing of events. @@ -228,8 +227,10 @@ class EmailPusher(Pusher): self.timed_call = None if soonest_due_at is not None: - self.timed_call = self.hs.get_reactor().callLater( - self.seconds_until(soonest_due_at), self.on_timer + delay = self.seconds_until(soonest_due_at) + self.timed_call = self.hs.get_clock().call_later( + delay, + self.on_timer, ) async def save_last_stream_ordering_and_success( diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 5946a6e972..5cac5de8cb 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -32,7 +32,6 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.logging import opentracing from synapse.metrics import SERVER_NAME_LABEL -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import Pusher, PusherConfig, PusherConfigException from synapse.storage.databases.main.event_push_actions import HttpPushAction from synapse.types import JsonDict, JsonMapping @@ -182,8 +181,8 @@ class HttpPusher(Pusher): # We could check the receipts are actually m.read receipts here, # but currently that's the only type of receipt anyway... - run_as_background_process( - "http_pusher.on_new_receipts", self.server_name, self._update_badge + self.hs.run_as_background_process( + "http_pusher.on_new_receipts", self._update_badge ) async def _update_badge(self) -> None: @@ -219,7 +218,7 @@ class HttpPusher(Pusher): if self.failing_since and self.timed_call and self.timed_call.active(): return - run_as_background_process("httppush.process", self.server_name, self._process) + self.hs.run_as_background_process("httppush.process", self._process) async def _process(self) -> None: # we should never get here if we are already processing @@ -336,8 +335,9 @@ class HttpPusher(Pusher): ) else: logger.info("Push failed: delaying for %ds", self.backoff_delay) - self.timed_call = self.hs.get_reactor().callLater( - self.backoff_delay, self.on_timer + self.timed_call = self.hs.get_clock().call_later( + self.backoff_delay, + self.on_timer, ) self.backoff_delay = min( self.backoff_delay * 2, self.MAX_BACKOFF_SEC diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index d1f79ec999..977c55b683 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -27,7 +27,6 @@ from prometheus_client import Gauge from synapse.api.errors import Codes, SynapseError from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import ( - run_as_background_process, wrap_as_background_process, ) from synapse.push import Pusher, PusherConfig, PusherConfigException @@ -70,10 +69,8 @@ class PusherPool: """ def __init__(self, hs: "HomeServer"): - self.hs = hs - self.server_name = ( - hs.hostname - ) # nb must be called this for @wrap_as_background_process + self.hs = hs # nb must be called this for @wrap_as_background_process + self.server_name = hs.hostname self.pusher_factory = PusherFactory(hs) self.store = self.hs.get_datastores().main self.clock = self.hs.get_clock() @@ -112,9 +109,7 @@ class PusherPool: if not self._should_start_pushers: logger.info("Not starting pushers because they are disabled in the config") return - run_as_background_process( - "start_pushers", self.server_name, self._start_pushers - ) + self.hs.run_as_background_process("start_pushers", self._start_pushers) async def add_or_update_pusher( self, diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index d96f5541f1..f2561bc0c5 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -32,7 +32,6 @@ from synapse.api.constants import EventTypes, Membership, ReceiptTypes from synapse.federation import send_queue from synapse.federation.sender import FederationSender from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.streams import ( AccountDataStream, DeviceListsStream, @@ -344,7 +343,9 @@ class ReplicationDataHandler: # to wedge here forever. deferred: "Deferred[None]" = Deferred() deferred = timeout_deferred( - deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor + deferred=deferred, + timeout=_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, + clock=self._clock, ) waiting_list = self._streams_to_waiters.setdefault( @@ -513,8 +514,8 @@ class FederationSenderHandler: # no need to queue up another task. return - run_as_background_process( - "_save_and_send_ack", self.server_name, self._save_and_send_ack + self._hs.run_as_background_process( + "_save_and_send_ack", self._save_and_send_ack ) async def _save_and_send_ack(self) -> None: diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index dd7e38dd78..4d0d3d44ab 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -41,7 +41,6 @@ from prometheus_client import Counter from twisted.internet.protocol import ReconnectingClientFactory from synapse.metrics import SERVER_NAME_LABEL, LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( ClearUserSyncsCommand, Command, @@ -132,6 +131,7 @@ class ReplicationCommandHandler: def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname + self.hs = hs self._replication_data_handler = hs.get_replication_data_handler() self._presence_handler = hs.get_presence_handler() self._store = hs.get_datastores().main @@ -361,9 +361,8 @@ class ReplicationCommandHandler: return # fire off a background process to start processing the queue. - run_as_background_process( + self.hs.run_as_background_process( "process-replication-data", - self.server_name, self._unsafe_process_queue, stream_name, ) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 25a7868cd7..bcfc65c2c0 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -42,7 +42,6 @@ from synapse.logging.context import PreserveLoggingContext from synapse.metrics import SERVER_NAME_LABEL, LaterGauge from synapse.metrics.background_process_metrics import ( BackgroundProcessLoggingContext, - run_as_background_process, ) from synapse.replication.tcp.commands import ( VALID_CLIENT_COMMANDS, @@ -140,9 +139,14 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): max_line_buffer = 10000 def __init__( - self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler" + self, + hs: "HomeServer", + server_name: str, + clock: Clock, + handler: "ReplicationCommandHandler", ): self.server_name = server_name + self.hs = hs self.clock = clock self.command_handler = handler @@ -290,9 +294,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # if so. if isawaitable(res): - run_as_background_process( + self.hs.run_as_background_process( "replication-" + cmd.get_logcontext_id(), - self.server_name, lambda: res, ) @@ -470,9 +473,13 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS def __init__( - self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler" + self, + hs: "HomeServer", + server_name: str, + clock: Clock, + handler: "ReplicationCommandHandler", ): - super().__init__(server_name, clock, handler) + super().__init__(hs, server_name, clock, handler) self.server_name = server_name @@ -497,7 +504,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): clock: Clock, command_handler: "ReplicationCommandHandler", ): - super().__init__(server_name, clock, command_handler) + super().__init__(hs, server_name, clock, command_handler) self.client_name = client_name self.server_name = server_name diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 0b1be033b1..caffb2913e 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -40,7 +40,6 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import ( BackgroundProcessLoggingContext, - run_as_background_process, wrap_as_background_process, ) from synapse.replication.tcp.commands import ( @@ -109,6 +108,7 @@ class RedisSubscriber(SubscriberProtocol): """ server_name: str + hs: "HomeServer" synapse_handler: "ReplicationCommandHandler" synapse_stream_prefix: str synapse_channel_names: List[str] @@ -146,9 +146,7 @@ class RedisSubscriber(SubscriberProtocol): def connectionMade(self) -> None: logger.info("Connected to redis") super().connectionMade() - run_as_background_process( - "subscribe-replication", self.server_name, self._send_subscribe - ) + self.hs.run_as_background_process("subscribe-replication", self._send_subscribe) async def _send_subscribe(self) -> None: # it's important to make sure that we only send the REPLICATE command once we @@ -223,8 +221,8 @@ class RedisSubscriber(SubscriberProtocol): # if so. if isawaitable(res): - run_as_background_process( - "replication-" + cmd.get_logcontext_id(), self.server_name, lambda: res + self.hs.run_as_background_process( + "replication-" + cmd.get_logcontext_id(), lambda: res ) def connectionLost(self, reason: Failure) -> None: # type: ignore[override] @@ -245,9 +243,8 @@ class RedisSubscriber(SubscriberProtocol): Args: cmd: The command to send """ - run_as_background_process( + self.hs.run_as_background_process( "send-cmd", - self.server_name, self._async_send_command, cmd, # We originally started tracing background processes to avoid `There was no @@ -317,9 +314,8 @@ class SynapseRedisFactory(RedisFactory): convertNumbers=convertNumbers, ) - self.server_name = ( - hs.hostname - ) # nb must be called this for @wrap_as_background_process + self.hs = hs # nb must be called this for @wrap_as_background_process + self.server_name = hs.hostname hs.get_clock().looping_call(self._send_ping, 30 * 1000) @@ -397,6 +393,7 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): ) self.server_name = hs.hostname + self.hs = hs self.synapse_handler = hs.get_replication_command_handler() self.synapse_stream_prefix = hs.hostname self.synapse_channel_names = channel_names @@ -412,6 +409,7 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): # the base method does some other things than just instantiating the # protocol. p.server_name = self.server_name + p.hs = self.hs p.synapse_handler = self.synapse_handler p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection p.synapse_stream_prefix = self.synapse_stream_prefix diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index d800cfe6f6..ef72a0a532 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -30,7 +30,6 @@ from twisted.internet.interfaces import IAddress from twisted.internet.protocol import ServerFactory from synapse.metrics import SERVER_NAME_LABEL -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import PositionCommand from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol from synapse.replication.tcp.streams import EventsStream @@ -55,6 +54,7 @@ class ReplicationStreamProtocolFactory(ServerFactory): def __init__(self, hs: "HomeServer"): self.command_handler = hs.get_replication_command_handler() self.clock = hs.get_clock() + self.hs = hs self.server_name = hs.config.server.server_name # If we've created a `ReplicationStreamProtocolFactory` then we're @@ -69,7 +69,7 @@ class ReplicationStreamProtocolFactory(ServerFactory): def buildProtocol(self, addr: IAddress) -> ServerReplicationStreamProtocol: return ServerReplicationStreamProtocol( - self.server_name, self.clock, self.command_handler + self.hs, self.server_name, self.clock, self.command_handler ) @@ -82,6 +82,7 @@ class ReplicationStreamer: def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname + self.hs = hs self.store = hs.get_datastores().main self.clock = hs.get_clock() self.notifier = hs.get_notifier() @@ -147,8 +148,8 @@ class ReplicationStreamer: logger.debug("Notifier poke loop already running") return - run_as_background_process( - "replication_notifier", self.server_name, self._run_notifier_loop + self.hs.run_as_background_process( + "replication_notifier", self._run_notifier_loop ) async def _run_notifier_loop(self) -> None: diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 25c15e5d48..87ac0a5ae1 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -77,6 +77,7 @@ STREAMS_MAP = { __all__ = [ "STREAMS_MAP", "Stream", + "EventsStream", "BackfillStream", "PresenceStream", "PresenceFederationStream", @@ -87,6 +88,7 @@ __all__ = [ "CachesStream", "DeviceListsStream", "ToDeviceStream", + "FederationStream", "AccountDataStream", "ThreadSubscriptionsStream", "UnPartialStatedRoomStream", diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 64deae7650..1084139df0 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -66,7 +66,6 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import set_tag from synapse.metrics import SERVER_NAME_LABEL -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client._base import client_patterns from synapse.rest.client.transactions import HttpTransactionCache from synapse.state import CREATE_KEY, POWER_KEY @@ -1225,6 +1224,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.server_name = hs.hostname + self.hs = hs self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() self._store = hs.get_datastores().main @@ -1307,9 +1307,8 @@ class RoomRedactEventRestServlet(TransactionRestServlet): ) if with_relations: - run_as_background_process( + self.hs.run_as_background_process( "redact_related_events", - self.server_name, self._relation_handler.redact_events_related_to, requester=requester, event_id=event_id, diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index bb63b51599..0f3cc84dcc 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -126,6 +126,7 @@ class SyncRestServlet(RestServlet): self._json_filter_cache: LruCache[str, bool] = LruCache( max_size=1000, + clock=self.clock, cache_name="sync_valid_filter", server_name=self.server_name, ) diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 1a57996aec..571ba2fa62 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -56,7 +56,7 @@ class HttpTransactionCache: ] = {} # Try to clean entries every 30 mins. This means entries will exist # for at *LEAST* 30 mins, and at *MOST* 60 mins. - self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS) + self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS) def _get_transaction_key(self, request: IRequest, requester: Requester) -> Hashable: """A helper function which returns a transaction key that can be used diff --git a/synapse/server.py b/synapse/server.py index edcab19d72..cc0d3a427b 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -28,10 +28,27 @@ import abc import functools import logging -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Type, TypeVar, cast +from threading import Thread +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, + cast, +) +from wsgiref.simple_server import WSGIServer +from attr import dataclass from typing_extensions import TypeAlias +from twisted.internet import defer +from twisted.internet.base import _SystemEventID from twisted.internet.interfaces import IOpenSSLContextFactory from twisted.internet.tcp import Port from twisted.python.threadpool import ThreadPool @@ -44,6 +61,7 @@ from synapse.api.auth.mas import MasDelegatedAuth from synapse.api.auth_blocking import AuthBlocking from synapse.api.filtering import Filtering from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter +from synapse.app._base import unregister_sighups from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.config.homeserver import HomeServerConfig @@ -133,6 +151,7 @@ from synapse.metrics import ( all_later_gauges_to_clean_up_on_shutdown, register_threadpool, ) +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager from synapse.module_api import ModuleApi from synapse.module_api.callbacks import ModuleApiCallbacks @@ -156,6 +175,7 @@ from synapse.storage.controllers import StorageControllers from synapse.streams.events import EventSources from synapse.synapse_rust.rendezvous import RendezvousHandler from synapse.types import DomainSpecificString, ISynapseReactor +from synapse.util.caches import CACHE_METRIC_REGISTRY from synapse.util.clock import Clock from synapse.util.distributor import Distributor from synapse.util.macaroons import MacaroonGenerator @@ -166,7 +186,9 @@ from synapse.util.task_scheduler import TaskScheduler logger = logging.getLogger(__name__) if TYPE_CHECKING: + # Old Python versions don't have `LiteralString` from txredisapi import ConnectionHandler + from typing_extensions import LiteralString from synapse.handlers.jwt import JwtHandler from synapse.handlers.oidc import OidcHandler @@ -196,6 +218,7 @@ if TYPE_CHECKING: T: TypeAlias = object F = TypeVar("F", bound=Callable[["HomeServer"], T]) +R = TypeVar("R") def cache_in_self(builder: F) -> F: @@ -219,7 +242,8 @@ def cache_in_self(builder: F) -> F: @functools.wraps(builder) def _get(self: "HomeServer") -> T: try: - return getattr(self, depname) + dep = getattr(self, depname) + return dep except AttributeError: pass @@ -239,6 +263,22 @@ def cache_in_self(builder: F) -> F: return cast(F, _get) +@dataclass +class ShutdownInfo: + """Information for callable functions called at time of shutdown. + + Attributes: + func: the object to call before shutdown. + trigger_id: an ID returned when registering this event trigger. + args: the arguments to call the function with. + kwargs: the keyword arguments to call the function with. + """ + + func: Callable[..., Any] + trigger_id: _SystemEventID + kwargs: Dict[str, object] + + class HomeServer(metaclass=abc.ABCMeta): """A basic homeserver object without lazy component builders. @@ -289,6 +329,7 @@ class HomeServer(metaclass=abc.ABCMeta): hostname : The hostname for the server. config: The full config for the homeserver. """ + if not reactor: from twisted.internet import reactor as _reactor @@ -300,6 +341,7 @@ class HomeServer(metaclass=abc.ABCMeta): self.signing_key = config.key.signing_key[0] self.config = config self._listening_services: List[Port] = [] + self._metrics_listeners: List[Tuple[WSGIServer, Thread]] = [] self.start_time: Optional[int] = None self._instance_id = random_string(5) @@ -315,6 +357,211 @@ class HomeServer(metaclass=abc.ABCMeta): # This attribute is set by the free function `refresh_certificate`. self.tls_server_context_factory: Optional[IOpenSSLContextFactory] = None + self._is_shutdown = False + self._async_shutdown_handlers: List[ShutdownInfo] = [] + self._sync_shutdown_handlers: List[ShutdownInfo] = [] + self._background_processes: set[defer.Deferred[Optional[Any]]] = set() + + def run_as_background_process( + self, + desc: "LiteralString", + func: Callable[..., Awaitable[Optional[R]]], + *args: Any, + **kwargs: Any, + ) -> "defer.Deferred[Optional[R]]": + """Run the given function in its own logcontext, with resource metrics + + This should be used to wrap processes which are fired off to run in the + background, instead of being associated with a particular request. + + It returns a Deferred which completes when the function completes, but it doesn't + follow the synapse logcontext rules, which makes it appropriate for passing to + clock.looping_call and friends (or for firing-and-forgetting in the middle of a + normal synapse async function). + + Because the returned Deferred does not follow the synapse logcontext rules, awaiting + the result of this function will result in the log context being cleared (bad). In + order to properly await the result of this function and maintain the current log + context, use `make_deferred_yieldable`. + + Args: + desc: a description for this background process type + server_name: The homeserver name that this background process is being run for + (this should be `hs.hostname`). + func: a function, which may return a Deferred or a coroutine + bg_start_span: Whether to start an opentracing span. Defaults to True. + Should only be disabled for processes that will not log to or tag + a span. + args: positional args for func + kwargs: keyword args for func + + Returns: + Deferred which returns the result of func, or `None` if func raises. + Note that the returned Deferred does not follow the synapse logcontext + rules. + """ + if self._is_shutdown: + raise Exception( + f"Cannot start background process. HomeServer has been shutdown {len(self._background_processes)} {len(self.get_clock()._looping_calls)} {len(self.get_clock()._call_id_to_delayed_call)}" + ) + + # Ignore linter error as this is the one location this should be called. + deferred = run_as_background_process(desc, self.hostname, func, *args, **kwargs) # type: ignore[untracked-background-process] + self._background_processes.add(deferred) + + def on_done(res: R) -> R: + try: + self._background_processes.remove(deferred) + except KeyError: + # If the background process isn't being tracked anymore we can just move on. + pass + return res + + deferred.addBoth(on_done) + return deferred + + async def shutdown(self) -> None: + """ + Cleanly stops all aspects of the HomeServer and removes any references that + have been handed out in order to allow the HomeServer object to be garbage + collected. + + You must ensure the HomeServer object to not be frozen in the garbage collector + in order for it to be cleaned up. By default, Synapse freezes the HomeServer + object in the garbage collector. + """ + + self._is_shutdown = True + + logger.info( + "Received shutdown request for %s (%s).", + self.hostname, + self.get_instance_id(), + ) + + # Unregister sighups first. If a shutdown was requested we shouldn't be responding + # to things like config changes. So it would be best to stop listening to these first. + unregister_sighups(self._instance_id) + + # TODO: It would be desireable to be able to report an error if the HomeServer + # object is frozen in the garbage collector as that would prevent it from being + # collected after being shutdown. + # In theory the following should work, but it doesn't seem to make a difference + # when I test it locally. + # + # if gc.is_tracked(self): + # logger.error("HomeServer object is tracked by garbage collection so cannot be fully cleaned up") + + for listener in self._listening_services: + # During unit tests, an incomplete `twisted.pair.testing._FakePort` is used + # for listeners so check listener type here to ensure shutdown procedure is + # only applied to actual `Port` instances. + if type(listener) is Port: + port_shutdown = listener.stopListening() + if port_shutdown is not None: + await port_shutdown + self._listening_services.clear() + + for server, thread in self._metrics_listeners: + server.shutdown() + thread.join() + self._metrics_listeners.clear() + + # TODO: Cleanup replication pieces + + self.get_keyring().shutdown() + + # Cleanup metrics associated with the homeserver + for later_gauge in all_later_gauges_to_clean_up_on_shutdown.values(): + later_gauge.unregister_hooks_for_homeserver_instance_id( + self.get_instance_id() + ) + + CACHE_METRIC_REGISTRY.unregister_hooks_for_homeserver( + self.config.server.server_name + ) + + for db in self.get_datastores().databases: + db.stop_background_updates() + + if self.should_send_federation(): + try: + self.get_federation_sender().shutdown() + except Exception: + pass + + for shutdown_handler in self._async_shutdown_handlers: + try: + self.get_reactor().removeSystemEventTrigger(shutdown_handler.trigger_id) + defer.ensureDeferred(shutdown_handler.func(**shutdown_handler.kwargs)) + except Exception as e: + logger.error("Error calling shutdown async handler: %s", e) + self._async_shutdown_handlers.clear() + + for shutdown_handler in self._sync_shutdown_handlers: + try: + self.get_reactor().removeSystemEventTrigger(shutdown_handler.trigger_id) + shutdown_handler.func(**shutdown_handler.kwargs) + except Exception as e: + logger.error("Error calling shutdown sync handler: %s", e) + self._sync_shutdown_handlers.clear() + + self.get_clock().shutdown() + + for background_process in list(self._background_processes): + try: + background_process.cancel() + except Exception: + pass + self._background_processes.clear() + + for db in self.get_datastores().databases: + db._db_pool.close() + + def register_async_shutdown_handler( + self, + *, + phase: str, + eventType: str, + shutdown_func: Callable[..., Any], + **kwargs: object, + ) -> None: + """ + Register a system event trigger with the HomeServer so it can be cleanly + removed when the HomeServer is shutdown. + """ + id = self.get_clock().add_system_event_trigger( + phase, + eventType, + shutdown_func, + **kwargs, + ) + self._async_shutdown_handlers.append( + ShutdownInfo(func=shutdown_func, trigger_id=id, kwargs=kwargs) + ) + + def register_sync_shutdown_handler( + self, + *, + phase: str, + eventType: str, + shutdown_func: Callable[..., Any], + **kwargs: object, + ) -> None: + """ + Register a system event trigger with the HomeServer so it can be cleanly + removed when the HomeServer is shutdown. + """ + id = self.get_clock().add_system_event_trigger( + phase, + eventType, + shutdown_func, + **kwargs, + ) + self._sync_shutdown_handlers.append( + ShutdownInfo(func=shutdown_func, trigger_id=id, kwargs=kwargs) + ) + def register_module_web_resource(self, path: str, resource: Resource) -> None: """Allows a module to register a web resource to be served at the given path. @@ -366,36 +613,25 @@ class HomeServer(metaclass=abc.ABCMeta): self.datastores = Databases(self.DATASTORE_CLASS, self) logger.info("Finished setting up.") - def __del__(self) -> None: - """ - Called when an the homeserver is garbage collected. + # Register background tasks required by this server. This must be done + # somewhat manually due to the background tasks not being registered + # unless handlers are instantiated. + if self.config.worker.run_background_tasks: + self.start_background_tasks() - Make sure we actually do some clean-up, rather than leak data. - """ - self.cleanup() - - def cleanup(self) -> None: - """ - WIP: Clean-up any references to the homeserver and stop any running related - processes, timers, loops, replication stream, etc. - - This should be called wherever you care about the HomeServer being completely - garbage collected like in tests. It's not necessary to call if you plan to just - shut down the whole Python process anyway. - - Can be called multiple times. - """ - logger.info("Received cleanup request for %s.", self.hostname) - - # TODO: Stop background processes, timers, loops, replication stream, etc. - - # Cleanup metrics associated with the homeserver - for later_gauge in all_later_gauges_to_clean_up_on_shutdown.values(): - later_gauge.unregister_hooks_for_homeserver_instance_id( - self.get_instance_id() - ) - - logger.info("Cleanup complete for %s.", self.hostname) + # def __del__(self) -> None: + # """ + # Called when an the homeserver is garbage collected. + # + # Make sure we actually do some clean-up, rather than leak data. + # """ + # + # # NOTE: This is a chicken and egg problem. + # # __del__ will never be called since the HomeServer cannot be garbage collected + # # until the shutdown function has been called. So it makes no sense to call + # # shutdown inside of __del__, even though that is a logical place to assume it + # # should be called. + # self.shutdown() def start_listening(self) -> None: # noqa: B027 (no-op by design) """Start the HTTP, manhole, metrics, etc listeners @@ -442,7 +678,8 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_clock(self) -> Clock: - return Clock(self._reactor, server_name=self.hostname) + # Ignore the linter error since this is the one place the `Clock` should be created. + return Clock(self._reactor, server_name=self.hostname) # type: ignore[multiple-internal-clocks] def get_datastores(self) -> Databases: if not self.datastores: @@ -452,7 +689,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_distributor(self) -> Distributor: - return Distributor(server_name=self.hostname) + return Distributor(hs=self) @cache_in_self def get_registration_ratelimiter(self) -> Ratelimiter: @@ -1007,8 +1244,10 @@ class HomeServer(metaclass=abc.ABCMeta): ) media_threadpool.start() - self.get_clock().add_system_event_trigger( - "during", "shutdown", media_threadpool.stop + self.register_sync_shutdown_handler( + phase="during", + eventType="shutdown", + shutdown_func=media_threadpool.stop, ) # Register the threadpool with our metrics. diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 19f86b5a56..73cf4091eb 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -36,6 +36,7 @@ SERVER_NOTICE_ROOM_TAG = "m.server_notice" class ServerNoticesManager: def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname # nb must be called this for @cached + self.clock = hs.get_clock() # nb must be called this for @cached self._store = hs.get_datastores().main self._config = hs.config self._account_data_handler = hs.get_account_data_handler() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index dd8d7135ba..394dc72fa6 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -651,6 +651,7 @@ class StateResolutionHandler: ExpiringCache( cache_name="state_cache", server_name=self.server_name, + hs=hs, clock=self.clock, max_len=100000, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index f214f55897..1fddcc0799 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -56,7 +56,7 @@ class SQLBaseStore(metaclass=ABCMeta): ): self.hs = hs self.server_name = hs.hostname # nb must be called this for @cached - self._clock = hs.get_clock() + self.clock = hs.get_clock() # nb must be called this for @cached self.database_engine = database.engine self.db_pool = database diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 9aa9e51aeb..e3e793d5f5 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -41,7 +41,6 @@ from typing import ( import attr from synapse._pydantic_compat import BaseModel -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.engines import PostgresEngine from synapse.storage.types import Connection, Cursor from synapse.types import JsonDict, StrCollection @@ -285,6 +284,13 @@ class BackgroundUpdater: self.sleep_duration_ms = hs.config.background_updates.sleep_duration_ms self.sleep_enabled = hs.config.background_updates.sleep_enabled + def shutdown(self) -> None: + """ + Stop any further background updates from happening. + """ + self.enabled = False + self._background_update_handlers.clear() + def get_status(self) -> UpdaterStatus: """An integer summarising the updater status. Used as a metric.""" if self._aborted: @@ -396,9 +402,8 @@ class BackgroundUpdater: # if we start a new background update, not all updates are done. self._all_done = False sleep = self.sleep_enabled - run_as_background_process( + self.hs.run_as_background_process( "background_updates", - self.server_name, self.run_background_updates, sleep, ) diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 120934af57..646e2cf115 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -62,7 +62,6 @@ from synapse.logging.opentracing import ( trace, ) from synapse.metrics import SERVER_NAME_LABEL -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.controllers.state import StateStorageController from synapse.storage.databases import Databases from synapse.storage.databases.main.events import DeltaState @@ -195,6 +194,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): def __init__( self, + hs: "HomeServer", server_name: str, per_item_callback: Callable[ [str, _EventPersistQueueTask], @@ -207,6 +207,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): and its result will be returned via the Deferreds returned from add_to_queue. """ self.server_name = server_name + self.hs = hs self._event_persist_queues: Dict[str, Deque[_EventPersistQueueItem]] = {} self._currently_persisting_rooms: Set[str] = set() self._per_item_callback = per_item_callback @@ -311,7 +312,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): self._currently_persisting_rooms.discard(room_id) # set handle_queue_loop off in the background - run_as_background_process("persist_events", self.server_name, handle_queue_loop) + self.hs.run_as_background_process("persist_events", handle_queue_loop) def _get_drainining_queue( self, room_id: str @@ -354,7 +355,7 @@ class EventsPersistenceStorageController: self._instance_name = hs.get_instance_name() self.is_mine_id = hs.is_mine_id self._event_persist_queue = _EventPeristenceQueue( - self.server_name, self._process_event_persist_queue_task + hs, self.server_name, self._process_event_persist_queue_task ) self._state_resolution_handler = hs.get_state_resolution_handler() self._state_controller = state_controller diff --git a/synapse/storage/controllers/purge_events.py b/synapse/storage/controllers/purge_events.py index 14b37ac543..ded9cb0567 100644 --- a/synapse/storage/controllers/purge_events.py +++ b/synapse/storage/controllers/purge_events.py @@ -46,9 +46,8 @@ class PurgeEventsStorageController: """High level interface for purging rooms and event history.""" def __init__(self, hs: "HomeServer", stores: Databases): - self.server_name = ( - hs.hostname - ) # nb must be called this for @wrap_as_background_process + self.hs = hs # nb must be called this for @wrap_as_background_process + self.server_name = hs.hostname self.stores = stores if hs.config.worker.run_background_tasks: diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 66f3289d86..76978402b9 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -69,8 +69,8 @@ class StateStorageController: def __init__(self, hs: "HomeServer", stores: "Databases"): self.server_name = hs.hostname # nb must be called this for @cached + self.clock = hs.get_clock() self._is_mine_id = hs.is_mine_id - self._clock = hs.get_clock() self.stores = stores self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main) @@ -78,7 +78,7 @@ class StateStorageController: # Used by `_get_joined_hosts` to ensure only one thing mutates the cache # at a time. Keyed by room_id. self._joined_host_linearizer = Linearizer( - name="_JoinedHostsCache", clock=self._clock + name="_JoinedHostsCache", clock=self.clock ) def notify_event_un_partial_stated(self, event_id: str) -> None: @@ -817,9 +817,7 @@ class StateStorageController: state_group = object() assert state_group is not None - with Measure( - self._clock, name="get_joined_hosts", server_name=self.server_name - ): + with Measure(self.clock, name="get_joined_hosts", server_name=self.server_name): return await self._get_joined_hosts( room_id, state_group, state_entry=state_entry ) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 249a0a933c..a4b2b26795 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -62,7 +62,6 @@ from synapse.logging.context import ( make_deferred_yieldable, ) from synapse.metrics import SERVER_NAME_LABEL, register_threadpool -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor, SQLQueryParameters @@ -638,12 +637,17 @@ class DatabasePool: # background updates of tables that aren't safe to update. self._clock.call_later( 0.0, - run_as_background_process, + self.hs.run_as_background_process, "upsert_safety_check", - self.server_name, self._check_safe_to_upsert, ) + def stop_background_updates(self) -> None: + """ + Stops the database from running any further background updates. + """ + self.updates.shutdown() + def name(self) -> str: "Return the name of this database" return self._database_config.name @@ -681,9 +685,8 @@ class DatabasePool: if background_update_names: self._clock.call_later( 15.0, - run_as_background_process, + self.hs.run_as_background_process, "upsert_safety_check", - self.server_name, self._check_safe_to_upsert, ) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index cad26fefa4..674c6b921e 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -751,7 +751,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): "instance_name": self._instance_name, "cache_func": cache_name, "keys": keys, - "invalidation_ts": self._clock.time_msec(), + "invalidation_ts": self.clock.time_msec(), }, ) @@ -778,7 +778,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): assert self._cache_id_gen is not None stream_ids = self._cache_id_gen.get_next_mult_txn(txn, len(key_tuples)) - ts = self._clock.time_msec() + ts = self.clock.time_msec() txn.call_after(self.hs.get_notifier().on_new_replication_data) self.db_pool.simple_insert_many_txn( txn, @@ -830,7 +830,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): next_interval = REGULAR_CLEANUP_INTERVAL_MS self.hs.get_clock().call_later( - next_interval / 1000, self._clean_up_cache_invalidation_wrapper + next_interval / 1000, + self._clean_up_cache_invalidation_wrapper, ) async def _clean_up_batch_of_old_cache_invalidations( diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index 3f9f482add..45cfe97dba 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -77,7 +77,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase return before_ts = ( - self._clock.time_msec() - self.hs.config.server.redaction_retention_period + self.clock.time_msec() - self.hs.config.server.redaction_retention_period ) # We fetch all redactions that: diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index c7a330cc83..dc6ab99a6c 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -438,10 +438,11 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke cache_name="client_ip_last_seen", server_name=self.server_name, max_size=50000, + clock=hs.get_clock(), ) if hs.config.worker.run_background_tasks and self.user_ips_max_age: - self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) + self.clock.looping_call(self._prune_old_user_ips, 5 * 1000) if self._update_on_this_worker: # This is the designated worker that can write to the client IP @@ -452,11 +453,11 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke Tuple[str, str, str], Tuple[str, Optional[str], int] ] = {} - self._client_ip_looper = self._clock.looping_call( - self._update_client_ips_batch, 5 * 1000 - ) - self.hs.get_clock().add_system_event_trigger( - "before", "shutdown", self._update_client_ips_batch + self.clock.looping_call(self._update_client_ips_batch, 5 * 1000) + hs.register_async_shutdown_handler( + phase="before", + eventType="shutdown", + shutdown_func=self._update_client_ips_batch, ) @wrap_as_background_process("prune_old_user_ips") @@ -492,7 +493,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke ) """ - timestamp = self._clock.time_msec() - self.user_ips_max_age + timestamp = self.clock.time_msec() - self.user_ips_max_age def _prune_old_user_ips_txn(txn: LoggingTransaction) -> None: txn.execute(sql, (timestamp,)) @@ -628,7 +629,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke return if not now: - now = int(self._clock.time_msec()) + now = int(self.clock.time_msec()) key = (user_id, access_token, ip) try: diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index f6f3c94a0d..a66e11f738 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -96,7 +96,8 @@ class DeviceInboxWorkerStore(SQLBaseStore): ] = ExpiringCache( cache_name="last_device_delete_cache", server_name=self.server_name, - clock=self._clock, + hs=hs, + clock=self.clock, max_len=10000, expiry_ms=30 * 60 * 1000, ) @@ -154,7 +155,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) if hs.config.worker.run_background_tasks: - self._clock.looping_call( + self.clock.looping_call( run_as_background_process, DEVICE_FEDERATION_INBOX_CLEANUP_INTERVAL_MS, "_delete_old_federation_inbox_rows", @@ -826,7 +827,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) async with self._to_device_msg_id_gen.get_next() as stream_id: - now_ms = self._clock.time_msec() + now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id ) @@ -881,7 +882,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) async with self._to_device_msg_id_gen.get_next() as stream_id: - now_ms = self._clock.time_msec() + now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", add_messages_txn, @@ -1002,7 +1003,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): # We delete at most 100 rows that are older than # DEVICE_FEDERATION_INBOX_CLEANUP_DELAY_MS delete_before_ts = ( - self._clock.time_msec() - DEVICE_FEDERATION_INBOX_CLEANUP_DELAY_MS + self.clock.time_msec() - DEVICE_FEDERATION_INBOX_CLEANUP_DELAY_MS ) sql = """ WITH to_delete AS ( @@ -1032,7 +1033,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): # We sleep a bit so that we don't hammer the database in a tight # loop first time we run this. - await self._clock.sleep(1) + await self.clock.sleep(1) async def get_devices_with_messages( self, user_id: str, device_ids: StrCollection diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index fc1e1c73f1..d4b9ce0ea0 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -195,7 +195,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) if hs.config.worker.run_background_tasks: - self._clock.looping_call( + self.clock.looping_call( self._prune_old_outbound_device_pokes, 60 * 60 * 1000 ) @@ -1390,7 +1390,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): table="device_lists_remote_resync", keyvalues={"user_id": user_id}, values={}, - insertion_values={"added_ts": self._clock.time_msec()}, + insertion_values={"added_ts": self.clock.time_msec()}, ) await self.db_pool.runInteraction( @@ -1601,7 +1601,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): that user when the destination comes back. It doesn't matter which device we keep. """ - yesterday = self._clock.time_msec() - prune_age + yesterday = self.clock.time_msec() - prune_age def _prune_txn(txn: LoggingTransaction) -> None: # look for (user, destination) pairs which have an update older than @@ -2086,7 +2086,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): stream_id, ) - now = self._clock.time_msec() + now = self.clock.time_msec() encoded_context = json_encoder.encode(context) mark_sent = not self.hs.is_mine_id(user_id) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 2e9f62075a..2d3d0c0036 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1564,7 +1564,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker DELETE FROM e2e_one_time_keys_json WHERE {clause} AND ts_added_ms < ? AND length(key_id) = 6 """ - args.append(self._clock.time_msec() - (7 * 24 * 3600 * 1000)) + args.append(self.clock.time_msec() - (7 * 24 * 3600 * 1000)) txn.execute(sql, args) return users, txn.rowcount @@ -1585,7 +1585,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker None, if there is no such key. Otherwise, the timestamp before which replacement is allowed without UIA. """ - timestamp = self._clock.time_msec() + duration_ms + timestamp = self.clock.time_msec() + duration_ms def impl(txn: LoggingTransaction) -> Optional[int]: txn.execute( diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 5c9bd2e848..d77420ff47 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -167,6 +167,7 @@ class EventFederationWorkerStore( # Cache of event ID to list of auth event IDs and their depths. self._event_auth_cache: LruCache[str, List[Tuple[str, int]]] = LruCache( max_size=500000, + clock=self.hs.get_clock(), server_name=self.server_name, cache_name="_event_auth_cache", size_callback=len, @@ -176,7 +177,7 @@ class EventFederationWorkerStore( # index. self.tests_allow_no_chain_cover_index = True - self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000) + self.clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000) if isinstance(self.database_engine, PostgresEngine): self.db_pool.updates.register_background_validate_constraint_and_delete_rows( @@ -1328,7 +1329,7 @@ class EventFederationWorkerStore( ( room_id, current_depth, - self._clock.time_msec(), + self.clock.time_msec(), BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS, limit, @@ -1841,7 +1842,7 @@ class EventFederationWorkerStore( last_cause=EXCLUDED.last_cause; """ - txn.execute(sql, (room_id, event_id, 1, self._clock.time_msec(), cause)) + txn.execute(sql, (room_id, event_id, 1, self.clock.time_msec(), cause)) @trace async def get_event_ids_with_failed_pull_attempts( @@ -1905,7 +1906,7 @@ class EventFederationWorkerStore( ), ) - current_time = self._clock.time_msec() + current_time = self.clock.time_msec() event_ids_with_backoff = {} for event_id, last_attempt_ts, num_attempts in event_failed_pull_attempts: @@ -2025,7 +2026,7 @@ class EventFederationWorkerStore( values={}, insertion_values={ "room_id": event.room_id, - "received_ts": self._clock.time_msec(), + "received_ts": self.clock.time_msec(), "event_json": json_encoder.encode(event.get_dict()), "internal_metadata": json_encoder.encode( event.internal_metadata.get_dict() @@ -2299,7 +2300,7 @@ class EventFederationWorkerStore( # If there is nothing in the staging area default it to 0. age = 0 if received_ts is not None: - age = self._clock.time_msec() - received_ts + age = self.clock.time_msec() - received_ts return count, age diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 4db0230421..ec26aedc6b 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -95,6 +95,8 @@ from typing import ( import attr +from twisted.internet.task import LoopingCall + from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -254,6 +256,8 @@ def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, st class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBaseStore): + _background_tasks: List[LoopingCall] = [] + def __init__( self, database: DatabasePool, @@ -263,7 +267,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas super().__init__(database, db_conn, hs) # Track when the process started. - self._started_ts = self._clock.time_msec() + self._started_ts = self.clock.time_msec() # These get correctly set by _find_stream_orderings_for_times_txn self.stream_ordering_month_ago: Optional[int] = None @@ -273,18 +277,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas self._find_stream_orderings_for_times_txn(cur) cur.close() - self.find_stream_orderings_looping_call = self._clock.looping_call( - self._find_stream_orderings_for_times, 10 * 60 * 1000 - ) + self.clock.looping_call(self._find_stream_orderings_for_times, 10 * 60 * 1000) self._rotate_count = 10000 self._doing_notif_rotation = False if hs.config.worker.run_background_tasks: - self._rotate_notif_loop = self._clock.looping_call( - self._rotate_notifs, 30 * 1000 - ) + self.clock.looping_call(self._rotate_notifs, 30 * 1000) - self._clear_old_staging_loop = self._clock.looping_call( + self.clock.looping_call( self._clear_old_push_actions_staging, 30 * 60 * 1000 ) @@ -1190,7 +1190,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas is_highlight, # highlight column int(count_as_unread), # unread column thread_id, # thread_id column - self._clock.time_msec(), # inserted_ts column + self.clock.time_msec(), # inserted_ts column ) await self.db_pool.simple_insert_many( @@ -1241,14 +1241,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def _find_stream_orderings_for_times_txn(self, txn: LoggingTransaction) -> None: logger.info("Searching for stream ordering 1 month ago") self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn( - txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000 + txn, self.clock.time_msec() - 30 * 24 * 60 * 60 * 1000 ) logger.info( "Found stream ordering 1 month ago: it's %d", self.stream_ordering_month_ago ) logger.info("Searching for stream ordering 1 day ago") self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn( - txn, self._clock.time_msec() - 24 * 60 * 60 * 1000 + txn, self.clock.time_msec() - 24 * 60 * 60 * 1000 ) logger.info( "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago @@ -1787,7 +1787,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # We delete anything more than an hour old, on the assumption that we'll # never take more than an hour to persist an event. - delete_before_ts = self._clock.time_msec() - 60 * 60 * 1000 + delete_before_ts = self.clock.time_msec() - 60 * 60 * 1000 if self._started_ts > delete_before_ts: # We need to wait for at least an hour before we started deleting, @@ -1824,7 +1824,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas return # We sleep to ensure that we don't overwhelm the DB. - await self._clock.sleep(1.0) + await self.clock.sleep(1.0) async def get_push_actions_for_user( self, diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 0a0102ee64..37dd8e48d5 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -730,7 +730,7 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS WHERE ? <= event_id AND event_id <= ? """ - txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id)) + txn.execute(sql, (self.clock.time_msec(), last_event_id, upper_event_id)) self.db_pool.updates._background_update_progress_txn( txn, "redactions_received_ts", {"last_event_id": upper_event_id} diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 31e2312211..4f9a1a4f78 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -70,7 +70,6 @@ from synapse.logging.opentracing import ( ) from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import ( - run_as_background_process, wrap_as_background_process, ) from synapse.replication.tcp.streams import BackfillStream, UnPartialStatedEventStream @@ -282,13 +281,14 @@ class EventsWorkerStore(SQLBaseStore): if hs.config.worker.run_background_tasks: # We periodically clean out old transaction ID mappings - self._clock.looping_call( + self.clock.looping_call( self._cleanup_old_transaction_ids, 5 * 60 * 1000, ) self._get_event_cache: AsyncLruCache[Tuple[str], EventCacheEntry] = ( AsyncLruCache( + clock=hs.get_clock(), server_name=self.server_name, cache_name="*getEvent*", max_size=hs.config.caches.event_cache_size, @@ -1154,9 +1154,7 @@ class EventsWorkerStore(SQLBaseStore): should_start = False if should_start: - run_as_background_process( - "fetch_events", self.server_name, self._fetch_thread - ) + self.hs.run_as_background_process("fetch_events", self._fetch_thread) async def _fetch_thread(self) -> None: """Services requests for events from `_event_fetch_list`.""" @@ -1276,7 +1274,7 @@ class EventsWorkerStore(SQLBaseStore): were not part of this request. """ with Measure( - self._clock, name="_fetch_event_list", server_name=self.server_name + self.clock, name="_fetch_event_list", server_name=self.server_name ): try: events_to_fetch = { @@ -2278,7 +2276,7 @@ class EventsWorkerStore(SQLBaseStore): """Cleans out transaction id mappings older than 24hrs.""" def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None: - one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000 + one_day_ago = self.clock.time_msec() - 24 * 60 * 60 * 1000 sql = """ DELETE FROM event_txn_id_device_id WHERE inserted_ts < ? @@ -2633,7 +2631,7 @@ class EventsWorkerStore(SQLBaseStore): keyvalues={"event_id": event_id}, values={ "reason": rejection_reason, - "last_check": self._clock.time_msec(), + "last_check": self.clock.time_msec(), }, ) self.db_pool.simple_update_txn( diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py index d0e4a91b59..e2b15eaf6a 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py @@ -28,7 +28,6 @@ from twisted.internet import defer from twisted.internet.task import LoopingCall from synapse.metrics.background_process_metrics import ( - run_as_background_process, wrap_as_background_process, ) from synapse.storage._base import SQLBaseStore @@ -99,15 +98,15 @@ class LockStore(SQLBaseStore): # lead to a race, as we may drop the lock while we are still processing. # However, a) it should be a small window, b) the lock is best effort # anyway and c) we want to really avoid leaking locks when we restart. - hs.get_clock().add_system_event_trigger( - "before", - "shutdown", - self._on_shutdown, + hs.register_async_shutdown_handler( + phase="before", + eventType="shutdown", + shutdown_func=self._on_shutdown, ) self._acquiring_locks: Set[Tuple[str, str]] = set() - self._clock.looping_call( + self.clock.looping_call( self._reap_stale_read_write_locks, _LOCK_TIMEOUT_MS / 10.0 ) @@ -153,7 +152,7 @@ class LockStore(SQLBaseStore): if lock and await lock.is_still_valid(): return None - now = self._clock.time_msec() + now = self.clock.time_msec() token = random_string(6) def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool: @@ -202,7 +201,8 @@ class LockStore(SQLBaseStore): lock = Lock( self.server_name, self._reactor, - self._clock, + self.hs, + self.clock, self, read_write=False, lock_name=lock_name, @@ -251,7 +251,7 @@ class LockStore(SQLBaseStore): # constraints. If it doesn't then we have acquired the lock, # otherwise we haven't. - now = self._clock.time_msec() + now = self.clock.time_msec() token = random_string(6) self.db_pool.simple_insert_txn( @@ -270,7 +270,8 @@ class LockStore(SQLBaseStore): lock = Lock( self.server_name, self._reactor, - self._clock, + self.hs, + self.clock, self, read_write=True, lock_name=lock_name, @@ -338,7 +339,7 @@ class LockStore(SQLBaseStore): """ def reap_stale_read_write_locks_txn(txn: LoggingTransaction) -> None: - txn.execute(delete_sql, (self._clock.time_msec() - _LOCK_TIMEOUT_MS,)) + txn.execute(delete_sql, (self.clock.time_msec() - _LOCK_TIMEOUT_MS,)) if txn.rowcount: logger.info("Reaped %d stale locks", txn.rowcount) @@ -374,6 +375,7 @@ class Lock: self, server_name: str, reactor: ISynapseReactor, + hs: "HomeServer", clock: Clock, store: LockStore, read_write: bool, @@ -387,6 +389,7 @@ class Lock: """ self._server_name = server_name self._reactor = reactor + self._hs = hs self._clock = clock self._store = store self._read_write = read_write @@ -410,6 +413,7 @@ class Lock: _RENEWAL_INTERVAL_MS, self._server_name, self._store, + self._hs, self._clock, self._read_write, self._lock_name, @@ -421,6 +425,7 @@ class Lock: def _renew( server_name: str, store: LockStore, + hs: "HomeServer", clock: Clock, read_write: bool, lock_name: str, @@ -457,9 +462,8 @@ class Lock: desc="renew_lock", ) - return run_as_background_process( + return hs.run_as_background_process( "Lock._renew", - server_name, _internal_renew, store, clock, diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index f726846e57..b8bd0042d7 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -565,7 +565,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): sql, ( user_id.to_string(), - self._clock.time_msec() - self.unused_expiration_time, + self.clock.time_msec() - self.unused_expiration_time, ), ) row = txn.fetchone() @@ -1059,7 +1059,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn: LoggingTransaction, ) -> int: # Calculate the timestamp for the start of the time period - start_ts = self._clock.time_msec() - time_period_ms + start_ts = self.clock.time_msec() - time_period_ms txn.execute(sql, (user_id, start_ts)) row = txn.fetchone() if row is None: diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index a3467bff3d..49411ed034 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -78,7 +78,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): # Read the extrems every 60 minutes if hs.config.worker.run_background_tasks: - self._clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000) + self.clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000) # Used in _generate_user_daily_visits to keep track of progress self._last_user_visit_update = self._get_start_of_day() @@ -224,7 +224,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): """ Counts the number of users who used this homeserver in the last 24 hours. """ - yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) + yesterday = int(self.clock.time_msec()) - (1000 * 60 * 60 * 24) return await self.db_pool.runInteraction( "count_daily_users", self._count_users, yesterday ) @@ -236,7 +236,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): from the mau figure in synapse.storage.monthly_active_users which, amongst other things, includes a 3 day grace period before a user counts. """ - thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + thirty_days_ago = int(self.clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) return await self.db_pool.runInteraction( "count_monthly_users", self._count_users, thirty_days_ago ) @@ -281,7 +281,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]: thirty_days_in_secs = 86400 * 30 - now = int(self._clock.time()) + now = int(self.clock.time()) sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs one_day_from_now_in_secs = now + 86400 @@ -389,7 +389,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): """ Returns millisecond unixtime for start of UTC day. """ - now = time.gmtime(self._clock.time()) + now = time.gmtime(self.clock.time()) today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) return today_start * 1000 @@ -403,7 +403,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): logger.info("Calling _generate_user_daily_visits") today_start = self._get_start_of_day() a_day_in_milliseconds = 24 * 60 * 60 * 1000 - now = self._clock.time_msec() + now = self.clock.time_msec() # A note on user_agent. Technically a given device can have multiple # user agents, so we need to decide which one to pick. We could have diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index f5a6b98be7..86744f616c 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -49,7 +49,6 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): hs: "HomeServer", ): super().__init__(database, db_conn, hs) - self._clock = hs.get_clock() self.hs = hs if hs.config.redis.redis_enabled: @@ -226,7 +225,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): reserved_users: reserved users to preserve """ - thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + thirty_days_ago = int(self.clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) in_clause, in_clause_args = make_in_list_sql_clause( self.database_engine, "user_id", reserved_users @@ -328,7 +327,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): txn, table="monthly_active_users", keyvalues={"user_id": user_id}, - values={"timestamp": int(self._clock.time_msec())}, + values={"timestamp": int(self.clock.time_msec())}, ) else: logger.warning("mau limit reserved threepid %s not found in db", tp) @@ -391,7 +390,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): txn, table="monthly_active_users", keyvalues={"user_id": user_id}, - values={"timestamp": int(self._clock.time_msec())}, + values={"timestamp": int(self.clock.time_msec())}, ) self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index ff4eb9acb2..f1dbf68971 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -1073,7 +1073,7 @@ class ReceiptsWorkerStore(SQLBaseStore): if event_ts is None: return None - now = self._clock.time_msec() + now = self.clock.time_msec() logger.debug( "Receipt %s for event %s in %s (%i ms old)", receipt_type, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 117444e7b7..906d1a91f6 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -212,7 +212,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): ) if hs.config.worker.run_background_tasks: - self._clock.call_later( + self.clock.call_later( 0.0, self._set_expiration_date_when_missing, ) @@ -226,7 +226,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): # Create a background job for culling expired 3PID validity tokens if hs.config.worker.run_background_tasks: - self._clock.looping_call( + self.clock.looping_call( self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS ) @@ -298,7 +298,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): ) -> None: user_id_obj = UserID.from_string(user_id) - now = int(self._clock.time()) + now = int(self.clock.time()) user_approved = approved or not self._require_approval @@ -457,7 +457,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): if not info: return False - now = self._clock.time_msec() + now = self.clock.time_msec() days = self.config.server.mau_appservice_trial_days.get( info.appservice_id, self.config.server.mau_trial_days ) @@ -640,7 +640,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): return await self.db_pool.runInteraction( "get_users_expiring_soon", select_users_txn, - self._clock.time_msec(), + self.clock.time_msec(), self.config.account_validity.account_validity_renew_at, ) @@ -1084,7 +1084,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): """ def _count_daily_user_type(txn: LoggingTransaction) -> Dict[str, int]: - yesterday = int(self._clock.time()) - (60 * 60 * 24) + yesterday = int(self.clock.time()) - (60 * 60 * 24) sql = """ SELECT user_type, COUNT(*) AS count FROM ( @@ -1496,7 +1496,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): await self.db_pool.runInteraction( "cull_expired_threepid_validation_tokens", cull_expired_threepid_validation_tokens_txn, - self._clock.time_msec(), + self.clock.time_msec(), ) @wrap_as_background_process("account_validity_set_expiration_dates") @@ -1537,7 +1537,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): random value in the [now + period - d ; now + period] range, d being a delta equal to 10% of the validity period. """ - now_ms = self._clock.time_msec() + now_ms = self.clock.time_msec() assert self._account_validity_period is not None expiration_ts = now_ms + self._account_validity_period @@ -1608,7 +1608,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): Raises: StoreError if there was a problem updating this. """ - now = self._clock.time_msec() + now = self.clock.time_msec() await self.db_pool.simple_update_one( "access_tokens", @@ -1639,7 +1639,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): uses_allowed, pending, completed, expiry_time = res # Check if the token has expired - now = self._clock.time_msec() + now = self.clock.time_msec() if expiry_time and expiry_time < now: return False @@ -1771,7 +1771,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): return await self.db_pool.runInteraction( "select_registration_tokens", select_registration_tokens_txn, - self._clock.time_msec(), + self.clock.time_msec(), valid, ) @@ -2251,7 +2251,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): "consume_login_token", self._consume_login_token, token, - self._clock.time_msec(), + self.clock.time_msec(), ) async def invalidate_login_tokens_by_session_id( @@ -2271,7 +2271,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): "auth_provider_id": auth_provider_id, "auth_provider_session_id": auth_provider_session_id, }, - updatevalues={"used_ts": self._clock.time_msec()}, + updatevalues={"used_ts": self.clock.time_msec()}, desc="invalidate_login_tokens_by_session_id", ) @@ -2640,7 +2640,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): ): super().__init__(database, db_conn, hs) - self._clock = hs.get_clock() self.config = hs.config self.db_pool.updates.register_background_index_update( @@ -2761,7 +2760,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): # Create a background job for removing expired login tokens if hs.config.worker.run_background_tasks: - self._clock.looping_call( + self.clock.looping_call( self._delete_expired_login_tokens, THIRTY_MINUTES_IN_MS ) @@ -2790,7 +2789,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): The token ID """ next_id = self._access_tokens_id_gen.get_next() - now = self._clock.time_msec() + now = self.clock.time_msec() await self.db_pool.simple_insert( "access_tokens", @@ -2874,7 +2873,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): keyvalues={"name": user_id}, updatevalues={ "consent_version": consent_version, - "consent_ts": self._clock.time_msec(), + "consent_ts": self.clock.time_msec(), }, ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) @@ -2986,7 +2985,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, - updatevalues={"validated_at": self._clock.time_msec()}, + updatevalues={"validated_at": self.clock.time_msec()}, ) return next_link @@ -3064,7 +3063,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): # We keep the expired tokens for an extra 5 minutes so we can measure how many # times a token is being used after its expiry - now = self._clock.time_msec() + now = self.clock.time_msec() await self.db_pool.runInteraction( "delete_expired_login_tokens", _delete_expired_login_tokens_txn, diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 9db2e14a06..65caf4b1ea 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -1002,7 +1002,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): """ with Measure( - self._clock, + self.clock, name="get_joined_user_ids_from_state", server_name=self.server_name, ): diff --git a/synapse/storage/databases/main/session.py b/synapse/storage/databases/main/session.py index 8a5fa8386c..1154bb2d59 100644 --- a/synapse/storage/databases/main/session.py +++ b/synapse/storage/databases/main/session.py @@ -55,7 +55,7 @@ class SessionStore(SQLBaseStore): # Create a background job for culling expired sessions. if hs.config.worker.run_background_tasks: - self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000) + self.clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000) async def create_session( self, session_type: str, value: JsonDict, expiry_ms: int @@ -133,7 +133,7 @@ class SessionStore(SQLBaseStore): _get_session, session_type, session_id, - self._clock.time_msec(), + self.clock.time_msec(), ) @wrap_as_background_process("delete_expired_sessions") @@ -147,5 +147,5 @@ class SessionStore(SQLBaseStore): await self.db_pool.runInteraction( "delete_expired_sessions", _delete_expired_sessions_txn, - self._clock.time_msec(), + self.clock.time_msec(), ) diff --git a/synapse/storage/databases/main/sliding_sync.py b/synapse/storage/databases/main/sliding_sync.py index f7af3e88d3..c0c5087b13 100644 --- a/synapse/storage/databases/main/sliding_sync.py +++ b/synapse/storage/databases/main/sliding_sync.py @@ -201,7 +201,7 @@ class SlidingSyncStore(SQLBaseStore): "user_id": user_id, "effective_device_id": device_id, "conn_id": conn_id, - "created_ts": self._clock.time_msec(), + "created_ts": self.clock.time_msec(), }, returning=("connection_key",), ) @@ -212,7 +212,7 @@ class SlidingSyncStore(SQLBaseStore): table="sliding_sync_connection_positions", values={ "connection_key": connection_key, - "created_ts": self._clock.time_msec(), + "created_ts": self.clock.time_msec(), }, returning=("connection_position",), ) diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index bfc324b80d..41c9483927 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -81,11 +81,11 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: - self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000) + self.clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000) @wrap_as_background_process("cleanup_transactions") async def _cleanup_transactions(self) -> None: - now = self._clock.time_msec() + now = self.clock.time_msec() day_ago = now - 24 * 60 * 60 * 1000 def _cleanup_transactions_txn(txn: LoggingTransaction) -> None: @@ -160,7 +160,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): insertion_values={ "response_code": code, "response_json": db_binary_type(encode_canonical_json(response_dict)), - "ts": self._clock.time_msec(), + "ts": self.clock.time_msec(), }, desc="set_received_txn_response", ) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 9b3b7e086f..b62f3e6f5b 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -125,6 +125,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self._state_group_cache: DictionaryCache[int, StateKey, str] = DictionaryCache( name="*stateGroupCache*", + clock=hs.get_clock(), server_name=self.server_name, # TODO: this hasn't been tuned yet max_entries=50000, @@ -132,6 +133,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self._state_group_members_cache: DictionaryCache[int, StateKey, str] = ( DictionaryCache( name="*stateGroupMembersCache*", + clock=hs.get_clock(), server_name=self.server_name, max_entries=500000, ) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 1f90988525..2a167f209c 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -55,7 +55,6 @@ from typing_extensions import Concatenate, ParamSpec, Unpack from twisted.internet import defer from twisted.internet.defer import CancelledError -from twisted.internet.interfaces import IReactorTime from twisted.python.failure import Failure from synapse.logging.context import ( @@ -549,10 +548,9 @@ class Linearizer: def __init__( self, - *, name: str, - max_count: int = 1, clock: Clock, + max_count: int = 1, ): """ Args: @@ -772,7 +770,11 @@ class ReadWriteLock: def timeout_deferred( - deferred: "defer.Deferred[_T]", timeout: float, reactor: IReactorTime + *, + deferred: "defer.Deferred[_T]", + timeout: float, + cancel_on_shutdown: bool = True, + clock: Clock, ) -> "defer.Deferred[_T]": """The in built twisted `Deferred.addTimeout` fails to time out deferreds that have a canceller that throws exceptions. This method creates a new @@ -790,7 +792,13 @@ def timeout_deferred( Args: deferred: The Deferred to potentially timeout. timeout: Timeout in seconds - reactor: The twisted reactor to use + cancel_on_shutdown: Whether this call should be tracked for cleanup during + shutdown. In general, all calls should be tracked. There may be a use case + not to track calls with a `timeout` of 0 (or similarly short) since tracking + them may result in rapid insertions and removals of tracked calls + unnecessarily. But unless a specific instance of tracking proves to be an + issue, we can just track all delayed calls. + clock: The `Clock` instance used to track delayed calls. Returns: @@ -814,7 +822,10 @@ def timeout_deferred( if not new_d.called: new_d.errback(defer.TimeoutError("Timed out after %gs" % (timeout,))) - delayed_call = reactor.callLater(timeout, time_it_out) + # We don't track these calls since they are short. + delayed_call = clock.call_later( + timeout, time_it_out, call_later_cancel_on_shutdown=cancel_on_shutdown + ) def convert_cancelled(value: Failure) -> Failure: # if the original deferred was cancelled, and our timeout has fired, then @@ -956,9 +967,9 @@ class AwakenableSleeper: currently sleeping. """ - def __init__(self, reactor: IReactorTime) -> None: + def __init__(self, clock: Clock) -> None: self._streams: Dict[str, Set[defer.Deferred[None]]] = {} - self._reactor = reactor + self._clock = clock def wake(self, name: str) -> None: """Wake everything related to `name` that is currently sleeping.""" @@ -977,7 +988,11 @@ class AwakenableSleeper: # Create a deferred that gets called in N seconds sleep_deferred: "defer.Deferred[None]" = defer.Deferred() - call = self._reactor.callLater(delay_ms / 1000, sleep_deferred.callback, None) + call = self._clock.call_later( + delay_ms / 1000, + sleep_deferred.callback, + None, + ) # Create a deferred that will get called if `wake` is called with # the same `name`. @@ -1011,8 +1026,8 @@ class AwakenableSleeper: class DeferredEvent: """Like threading.Event but for async code""" - def __init__(self, reactor: IReactorTime) -> None: - self._reactor = reactor + def __init__(self, clock: Clock) -> None: + self._clock = clock self._deferred: "defer.Deferred[None]" = defer.Deferred() def set(self) -> None: @@ -1032,7 +1047,11 @@ class DeferredEvent: # Create a deferred that gets called in N seconds sleep_deferred: "defer.Deferred[None]" = defer.Deferred() - call = self._reactor.callLater(timeout_seconds, sleep_deferred.callback, None) + call = self._clock.call_later( + timeout_seconds, + sleep_deferred.callback, + None, + ) try: await make_deferred_yieldable( diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py index 4c4037412a..f77301afd8 100644 --- a/synapse/util/batching_queue.py +++ b/synapse/util/batching_queue.py @@ -21,6 +21,7 @@ import logging from typing import ( + TYPE_CHECKING, Awaitable, Callable, Dict, @@ -38,9 +39,11 @@ from twisted.internet import defer from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics import SERVER_NAME_LABEL -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.clock import Clock +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -97,12 +100,13 @@ class BatchingQueue(Generic[V, R]): self, *, name: str, - server_name: str, + hs: "HomeServer", clock: Clock, process_batch_callback: Callable[[List[V]], Awaitable[R]], ): self._name = name - self.server_name = server_name + self.hs = hs + self.server_name = hs.hostname self._clock = clock # The set of keys currently being processed. @@ -127,6 +131,14 @@ class BatchingQueue(Generic[V, R]): name=self._name, **{SERVER_NAME_LABEL: self.server_name} ) + def shutdown(self) -> None: + """ + Prepares the object for garbage collection by removing any handed out + references. + """ + number_queued.remove(self._name, self.server_name) + number_of_keys.remove(self._name, self.server_name) + async def add_to_queue(self, value: V, key: Hashable = ()) -> R: """Adds the value to the queue with the given key, returning the result of the processing function for the batch that included the given value. @@ -145,9 +157,7 @@ class BatchingQueue(Generic[V, R]): # If we're not currently processing the key fire off a background # process to start processing. if key not in self._processing_keys: - run_as_background_process( - self._name, self.server_name, self._process_queue, key - ) + self.hs.run_as_background_process(self._name, self._process_queue, key) with self._number_in_flight_metric.track_inprogress(): return await make_deferred_yieldable(d) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 710a29e3f0..08ff842af0 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -244,7 +244,7 @@ def register_cache( collect_callback=collect_callback, ) metric_name = "cache_%s_%s_%s" % (cache_type, cache_name, server_name) - CACHE_METRIC_REGISTRY.register_hook(metric_name, metric.collect) + CACHE_METRIC_REGISTRY.register_hook(server_name, metric_name, metric.collect) return metric diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 92d446ce2a..016acbac71 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -47,6 +47,7 @@ from synapse.metrics import SERVER_NAME_LABEL from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry +from synapse.util.clock import Clock cache_pending_metric = Gauge( "synapse_util_caches_cache_pending", @@ -82,6 +83,7 @@ class DeferredCache(Generic[KT, VT]): self, *, name: str, + clock: Clock, server_name: str, max_entries: int = 1000, tree: bool = False, @@ -103,6 +105,7 @@ class DeferredCache(Generic[KT, VT]): prune_unread_entries: If True, cache entries that haven't been read recently will be evicted from the cache in the background. Set to False to opt-out of this behaviour. + clock: The homeserver `Clock` instance """ cache_type = TreeCache if tree else dict @@ -120,6 +123,7 @@ class DeferredCache(Generic[KT, VT]): # a Deferred. self.cache: LruCache[KT, VT] = LruCache( max_size=max_entries, + clock=clock, server_name=server_name, cache_name=name, cache_type=cache_type, diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 47b8f4ddc8..6e3c8eada9 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -53,6 +53,7 @@ from synapse.util import unwrapFirstError from synapse.util.async_helpers import delay_cancellation from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.lrucache import LruCache +from synapse.util.clock import Clock logger = logging.getLogger(__name__) @@ -154,13 +155,20 @@ class _CacheDescriptorBase: ) -class HasServerName(Protocol): +class HasServerNameAndClock(Protocol): server_name: str """ The homeserver name that this cache is associated with (used to label the metric) (`hs.hostname`). """ + clock: Clock + """ + The homeserver clock instance used to track delayed and looping calls. Important to + be able to fully cleanup the homeserver instance on server shutdown. + (`hs.get_clock()`). + """ + class DeferredCacheDescriptor(_CacheDescriptorBase): """A method decorator that applies a memoizing cache around the function. @@ -239,7 +247,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): self.prune_unread_entries = prune_unread_entries def __get__( - self, obj: Optional[HasServerName], owner: Optional[Type] + self, obj: Optional[HasServerNameAndClock], owner: Optional[Type] ) -> Callable[..., "defer.Deferred[Any]"]: # We need access to instance-level `obj.server_name` attribute assert obj is not None, ( @@ -249,9 +257,13 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): assert obj.server_name is not None, ( "The `server_name` attribute must be set on the object where `@cached` decorator is used." ) + assert obj.clock is not None, ( + "The `clock` attribute must be set on the object where `@cached` decorator is used." + ) cache: DeferredCache[CacheKey, Any] = DeferredCache( name=self.name, + clock=obj.clock, server_name=obj.server_name, max_entries=self.max_entries, tree=self.tree, diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 168ddc51cd..eb5493d322 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -37,6 +37,7 @@ import attr from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache +from synapse.util.clock import Clock logger = logging.getLogger(__name__) @@ -127,10 +128,13 @@ class DictionaryCache(Generic[KT, DKT, DV]): for the '2' dict key. """ - def __init__(self, *, name: str, server_name: str, max_entries: int = 1000): + def __init__( + self, *, name: str, clock: Clock, server_name: str, max_entries: int = 1000 + ): """ Args: name + clock: The homeserver `Clock` instance server_name: The homeserver name that this cache is associated with (used to label the metric) (`hs.hostname`). max_entries @@ -160,6 +164,7 @@ class DictionaryCache(Generic[KT, DKT, DV]): Union[_PerKeyValue, Dict[DKT, DV]], ] = LruCache( max_size=max_entries, + clock=clock, server_name=server_name, cache_name=name, cache_type=TreeCache, diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 305af5051c..29ce6c0a77 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -21,17 +21,29 @@ import logging from collections import OrderedDict -from typing import Any, Generic, Iterable, Literal, Optional, TypeVar, Union, overload +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Iterable, + Literal, + Optional, + TypeVar, + Union, + overload, +) import attr from twisted.internet import defer from synapse.config import cache as cache_config -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches import EvictionReason, register_cache from synapse.util.clock import Clock +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -49,6 +61,7 @@ class ExpiringCache(Generic[KT, VT]): *, cache_name: str, server_name: str, + hs: "HomeServer", clock: Clock, max_len: int = 0, expiry_ms: int = 0, @@ -99,9 +112,7 @@ class ExpiringCache(Generic[KT, VT]): return def f() -> "defer.Deferred[None]": - return run_as_background_process( - "prune_cache", server_name, self._prune_cache - ) + return hs.run_as_background_process("prune_cache", self._prune_cache) self._clock.looping_call(f, self._expiry_ms / 2) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 2d4cde19a5..324acb728a 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -45,14 +45,10 @@ from typing import ( overload, ) -from twisted.internet import defer, reactor +from twisted.internet import defer from synapse.config import cache as cache_config -from synapse.metrics.background_process_metrics import ( - run_as_background_process, -) from synapse.metrics.jemalloc import get_jemalloc_stats -from synapse.types import ISynapseThreadlessReactor from synapse.util import caches from synapse.util.caches import CacheMetric, EvictionReason, register_cache from synapse.util.caches.treecache import ( @@ -123,6 +119,7 @@ GLOBAL_ROOT = ListNode["_Node"].create_root_node() def _expire_old_entries( server_name: str, + hs: "HomeServer", clock: Clock, expiry_seconds: float, autotune_config: Optional[dict], @@ -228,9 +225,8 @@ def _expire_old_entries( logger.info("Dropped %d items from caches", i) - return run_as_background_process( + return hs.run_as_background_process( "LruCache._expire_old_entries", - server_name, _internal_expire_old_entries, clock, expiry_seconds, @@ -261,6 +257,7 @@ def setup_expire_lru_cache_entries(hs: "HomeServer") -> None: _expire_old_entries, 30 * 1000, server_name, + hs, clock, expiry_time, hs.config.caches.cache_autotuning, @@ -404,13 +401,13 @@ class LruCache(Generic[KT, VT]): self, *, max_size: int, + clock: Clock, server_name: str, cache_name: str, cache_type: Type[Union[dict, TreeCache]] = dict, size_callback: Optional[Callable[[VT], int]] = None, metrics_collection_callback: Optional[Callable[[], None]] = None, apply_cache_factor_from_config: bool = True, - clock: Optional[Clock] = None, prune_unread_entries: bool = True, extra_index_cb: Optional[Callable[[KT, VT], KT]] = None, ): ... @@ -420,13 +417,13 @@ class LruCache(Generic[KT, VT]): self, *, max_size: int, + clock: Clock, server_name: str, cache_name: Literal[None] = None, cache_type: Type[Union[dict, TreeCache]] = dict, size_callback: Optional[Callable[[VT], int]] = None, metrics_collection_callback: Optional[Callable[[], None]] = None, apply_cache_factor_from_config: bool = True, - clock: Optional[Clock] = None, prune_unread_entries: bool = True, extra_index_cb: Optional[Callable[[KT, VT], KT]] = None, ): ... @@ -435,13 +432,13 @@ class LruCache(Generic[KT, VT]): self, *, max_size: int, + clock: Clock, server_name: str, cache_name: Optional[str] = None, cache_type: Type[Union[dict, TreeCache]] = dict, size_callback: Optional[Callable[[VT], int]] = None, metrics_collection_callback: Optional[Callable[[], None]] = None, apply_cache_factor_from_config: bool = True, - clock: Optional[Clock] = None, prune_unread_entries: bool = True, extra_index_cb: Optional[Callable[[KT, VT], KT]] = None, ): @@ -492,15 +489,6 @@ class LruCache(Generic[KT, VT]): Note: The new key does not have to be unique. """ - # Default `clock` to something sensible. Note that we rename it to - # `real_clock` so that mypy doesn't think its still `Optional`. - if clock is None: - real_clock = Clock( - cast(ISynapseThreadlessReactor, reactor), server_name=server_name - ) - else: - real_clock = clock - cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type() self.cache = cache # Used for introspection. self.apply_cache_factor_from_config = apply_cache_factor_from_config @@ -592,7 +580,7 @@ class LruCache(Generic[KT, VT]): key, value, weak_ref_to_self, - real_clock, + clock, callbacks, prune_unread_entries, ) @@ -610,7 +598,7 @@ class LruCache(Generic[KT, VT]): metrics.inc_memory_usage(node.memory) def move_node_to_front(node: _Node[KT, VT]) -> None: - node.move_to_front(real_clock, list_root) + node.move_to_front(clock, list_root) def delete_node(node: _Node[KT, VT]) -> int: node.drop_from_lists() diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 79e34262df..3d39357236 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -198,7 +198,17 @@ class ResponseCache(Generic[KV]): # the should_cache bit, we leave it in the cache for now and schedule # its removal later. if self.timeout_sec and context.should_cache: - self.clock.call_later(self.timeout_sec, self._entry_timeout, key) + self.clock.call_later( + self.timeout_sec, + self._entry_timeout, + key, + # We don't need to track these calls since they don't hold any strong + # references which would keep the `HomeServer` in memory after shutdown. + # We don't want to track these because they can get cancelled really + # quickly and thrash the tracking mechanism, ie. during repeated calls + # to /sync. + call_later_cancel_on_shutdown=False, + ) else: # otherwise, remove the result immediately. self.unset(key) diff --git a/synapse/util/clock.py b/synapse/util/clock.py index e85af17005..5e65cf32a4 100644 --- a/synapse/util/clock.py +++ b/synapse/util/clock.py @@ -17,10 +17,12 @@ from typing import ( Any, Callable, + Dict, + List, ) -import attr from typing_extensions import ParamSpec +from zope.interface import implementer from twisted.internet import defer, task from twisted.internet.defer import Deferred @@ -34,24 +36,54 @@ from synapse.util import log_failure P = ParamSpec("P") -@attr.s(slots=True) class Clock: """ A Clock wraps a Twisted reactor and provides utilities on top of it. + This clock should be used in place of calls to the base reactor wherever `LoopingCall` + or `DelayedCall` are made (such as when calling `reactor.callLater`. This is to + ensure the calls made by this `HomeServer` instance are tracked and can be cleaned + up during `HomeServer.shutdown()`. + + We enforce usage of this clock instead of using the reactor directly via lints in + `scripts-dev/mypy_synapse_plugin.py`. + + Args: reactor: The Twisted reactor to use. """ - _reactor: ISynapseThreadlessReactor = attr.ib() - _server_name: str = attr.ib() + _reactor: ISynapseThreadlessReactor + + def __init__(self, reactor: ISynapseThreadlessReactor, server_name: str) -> None: + self._reactor = reactor + self._server_name = server_name + + self._delayed_call_id: int = 0 + """Unique ID used to track delayed calls""" + + self._looping_calls: List[LoopingCall] = [] + """List of active looping calls""" + + self._call_id_to_delayed_call: Dict[int, IDelayedCall] = {} + """Mapping from unique call ID to delayed call""" + + self._is_shutdown = False + """Whether shutdown has been requested by the HomeServer""" + + def shutdown(self) -> None: + self._is_shutdown = True + self.cancel_all_looping_calls() + self.cancel_all_delayed_calls() async def sleep(self, seconds: float) -> None: d: defer.Deferred[float] = defer.Deferred() # Start task in the `sentinel` logcontext, to avoid leaking the current context # into the reactor once it finishes. with context.PreserveLoggingContext(): - self._reactor.callLater(seconds, d.callback, seconds) + # We can ignore the lint here since this class is the one location callLater should + # be called. + self._reactor.callLater(seconds, d.callback, seconds) # type: ignore[call-later-not-tracked] await d def time(self) -> float: @@ -124,6 +156,9 @@ class Clock: ) -> LoopingCall: """Common functionality for `looping_call` and `looping_call_now`""" + if self._is_shutdown: + raise Exception("Cannot start looping call. Clock has been shutdown") + def wrapped_f(*args: P.args, **kwargs: P.kwargs) -> Deferred: assert context.current_context() is context.SENTINEL_CONTEXT, ( "Expected `looping_call` callback from the reactor to start with the sentinel logcontext " @@ -155,7 +190,9 @@ class Clock: # logcontext to the reactor return context.run_in_background(f, *args, **kwargs) - call = task.LoopingCall(wrapped_f, *args, **kwargs) + # We can ignore the lint here since this is the one location LoopingCall's + # should be created. + call = task.LoopingCall(wrapped_f, *args, **kwargs) # type: ignore[prefer-synapse-clock-looping-call] call.clock = self._reactor # If `now=true`, the function will be called here immediately so we need to be # in the sentinel context now. @@ -165,10 +202,32 @@ class Clock: with context.PreserveLoggingContext(): d = call.start(msec / 1000.0, now=now) d.addErrback(log_failure, "Looping call died", consumeErrors=False) + self._looping_calls.append(call) return call + def cancel_all_looping_calls(self, consumeErrors: bool = True) -> None: + """ + Stop all running looping calls. + + Args: + consumeErrors: Whether to re-raise errors encountered when cancelling the + scheduled call. + """ + for call in self._looping_calls: + try: + call.stop() + except Exception: + if not consumeErrors: + raise + self._looping_calls.clear() + def call_later( - self, delay: float, callback: Callable, *args: Any, **kwargs: Any + self, + delay: float, + callback: Callable, + *args: Any, + call_later_cancel_on_shutdown: bool = True, + **kwargs: Any, ) -> IDelayedCall: """Call something later @@ -180,39 +239,78 @@ class Clock: delay: How long to wait in seconds. callback: Function to call *args: Postional arguments to pass to function. + call_later_cancel_on_shutdown: Whether this call should be tracked for cleanup during + shutdown. In general, all calls should be tracked. There may be a use case + not to track calls with a `timeout` of 0 (or similarly short) since tracking + them may result in rapid insertions and removals of tracked calls + unnecessarily. But unless a specific instance of tracking proves to be an + issue, we can just track all delayed calls. **kwargs: Key arguments to pass to function. """ - def wrapped_callback(*args: Any, **kwargs: Any) -> None: - assert context.current_context() is context.SENTINEL_CONTEXT, ( - "Expected `call_later` callback from the reactor to start with the sentinel logcontext " - f"but saw {context.current_context()}. In other words, another task shouldn't have " - "leaked their logcontext to us." - ) + if self._is_shutdown: + raise Exception("Cannot start delayed call. Clock has been shutdown") - # Because this is a callback from the reactor, we will be using the - # `sentinel` log context at this point. We want the function to log with - # some logcontext as we want to know which server the logs came from. - # - # We use `PreserveLoggingContext` to prevent our new `call_later` - # logcontext from finishing as soon as we exit this function, in case `f` - # returns an awaitable/deferred which would continue running and may try to - # restore the `loop_call` context when it's done (because it's trying to - # adhere to the Synapse logcontext rules.) - # - # This also ensures that we return to the `sentinel` context when we exit - # this function and yield control back to the reactor to avoid leaking the - # current logcontext to the reactor (which would then get picked up and - # associated with the next thing the reactor does) - with context.PreserveLoggingContext( - context.LoggingContext(name="call_later", server_name=self._server_name) - ): - # We use `run_in_background` to reset the logcontext after `f` (or the - # awaitable returned by `f`) completes to avoid leaking the current - # logcontext to the reactor - context.run_in_background(callback, *args, **kwargs) + def create_wrapped_callback( + track_for_shutdown_cancellation: bool, + ) -> Callable[P, None]: + def wrapped_callback(*args: Any, **kwargs: Any) -> None: + assert context.current_context() is context.SENTINEL_CONTEXT, ( + "Expected `call_later` callback from the reactor to start with the sentinel logcontext " + f"but saw {context.current_context()}. In other words, another task shouldn't have " + "leaked their logcontext to us." + ) - return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs) + # Because this is a callback from the reactor, we will be using the + # `sentinel` log context at this point. We want the function to log with + # some logcontext as we want to know which server the logs came from. + # + # We use `PreserveLoggingContext` to prevent our new `call_later` + # logcontext from finishing as soon as we exit this function, in case `f` + # returns an awaitable/deferred which would continue running and may try to + # restore the `loop_call` context when it's done (because it's trying to + # adhere to the Synapse logcontext rules.) + # + # This also ensures that we return to the `sentinel` context when we exit + # this function and yield control back to the reactor to avoid leaking the + # current logcontext to the reactor (which would then get picked up and + # associated with the next thing the reactor does) + try: + with context.PreserveLoggingContext( + context.LoggingContext( + name="call_later", server_name=self._server_name + ) + ): + # We use `run_in_background` to reset the logcontext after `f` (or the + # awaitable returned by `f`) completes to avoid leaking the current + # logcontext to the reactor + context.run_in_background(callback, *args, **kwargs) + finally: + if track_for_shutdown_cancellation: + # We still want to remove the call from the tracking map. Even if + # the callback raises an exception. + self._call_id_to_delayed_call.pop(call_id) + + return wrapped_callback + + if call_later_cancel_on_shutdown: + call_id = self._delayed_call_id + self._delayed_call_id = self._delayed_call_id + 1 + + # We can ignore the lint here since this class is the one location callLater + # should be called. + call = self._reactor.callLater( + delay, create_wrapped_callback(True), *args, **kwargs + ) # type: ignore[call-later-not-tracked] + call = DelayedCallWrapper(call, call_id, self) + self._call_id_to_delayed_call[call_id] = call + return call + else: + # We can ignore the lint here since this class is the one location callLater should + # be called. + return self._reactor.callLater( + delay, create_wrapped_callback(False), *args, **kwargs + ) # type: ignore[call-later-not-tracked] def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> None: try: @@ -221,6 +319,24 @@ class Clock: if not ignore_errs: raise + def cancel_all_delayed_calls(self, ignore_errs: bool = True) -> None: + """ + Stop all scheduled calls that were marked with `cancel_on_shutdown` when they were created. + + Args: + ignore_errs: Whether to re-raise errors encountered when cancelling the + scheduled call. + """ + # We make a copy here since calling `cancel()` on a delayed_call + # will result in the call removing itself from the map mid-iteration. + for call in list(self._call_id_to_delayed_call.values()): + try: + call.cancel() + except Exception: + if not ignore_errs: + raise + self._call_id_to_delayed_call.clear() + def call_when_running( self, callback: Callable[P, object], @@ -285,7 +401,7 @@ class Clock: callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs, - ) -> None: + ) -> Any: """ Add a function to be called when a system event occurs. @@ -299,6 +415,9 @@ class Clock: callback: Function to call *args: Postional arguments to pass to function. **kwargs: Key arguments to pass to function. + + Returns: + an ID that can be used to remove this call with `reactor.removeSystemEventTrigger`. """ def wrapped_callback(*args: Any, **kwargs: Any) -> None: @@ -334,6 +453,50 @@ class Clock: # We can ignore the lint here since this class is the one location # `addSystemEventTrigger` should be called. - self._reactor.addSystemEventTrigger( + return self._reactor.addSystemEventTrigger( phase, event_type, wrapped_callback, *args, **kwargs ) # type: ignore[prefer-synapse-clock-add-system-event-trigger] + + +@implementer(IDelayedCall) +class DelayedCallWrapper: + """Wraps an `IDelayedCall` so that we can intercept the call to `cancel()` and + properly cleanup the delayed call from the tracking map of the `Clock`. + + args: + delayed_call: The actual `IDelayedCall` + call_id: Unique identifier for this delayed call + clock: The clock instance tracking this call + """ + + def __init__(self, delayed_call: IDelayedCall, call_id: int, clock: Clock): + self.delayed_call = delayed_call + self.call_id = call_id + self.clock = clock + + def cancel(self) -> None: + """Remove the call from the tracking map and propagate the call to the + underlying delayed_call. + """ + self.delayed_call.cancel() + try: + self.clock._call_id_to_delayed_call.pop(self.call_id) + except KeyError: + # If the delayed call isn't being tracked anymore we can just move on. + pass + + def getTime(self) -> float: + """Propagate the call to the underlying delayed_call.""" + return self.delayed_call.getTime() + + def delay(self, secondsLater: float) -> None: + """Propagate the call to the underlying delayed_call.""" + self.delayed_call.delay(secondsLater) + + def reset(self, secondsFromNow: float) -> None: + """Propagate the call to the underlying delayed_call.""" + self.delayed_call.reset(secondsFromNow) + + def active(self) -> bool: + """Propagate the call to the underlying delayed_call.""" + return self.delayed_call.active() diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index f48ae3373c..dec6536e4e 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -20,6 +20,7 @@ # import logging from typing import ( + TYPE_CHECKING, Any, Awaitable, Callable, @@ -36,10 +37,13 @@ from typing_extensions import ParamSpec from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID from synapse.util.async_helpers import maybe_awaitable +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) @@ -58,13 +62,13 @@ class Distributor: model will do for today. """ - def __init__(self, server_name: str) -> None: + def __init__(self, hs: "HomeServer") -> None: """ Args: server_name: The homeserver name of the server (used to label metrics) (this should be `hs.hostname`). """ - self.server_name = server_name + self.hs = hs self.signals: Dict[str, Signal] = {} self.pre_registration: Dict[str, List[Callable]] = {} @@ -97,8 +101,8 @@ class Distributor: if name not in self.signals: raise KeyError("%r does not have a signal named %s" % (self, name)) - run_as_background_process( - name, self.server_name, self.signals[name].fire, *args, **kwargs + self.hs.run_as_background_process( + name, self.signals[name].fire, *args, **kwargs ) diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index c4f3c8b965..7b6ad0e459 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -293,21 +293,46 @@ class DynamicCollectorRegistry(CollectorRegistry): def __init__(self) -> None: super().__init__() - self._pre_update_hooks: Dict[str, Callable[[], None]] = {} + self._server_name_to_pre_update_hooks: Dict[ + str, Dict[str, Callable[[], None]] + ] = {} + """ + Mapping of server name to a mapping of metric name to metric pre-update + hook + """ def collect(self) -> Generator[Metric, None, None]: """ Collects metrics, calling pre-update hooks first. """ - for pre_update_hook in self._pre_update_hooks.values(): - pre_update_hook() + for pre_update_hooks in self._server_name_to_pre_update_hooks.values(): + for pre_update_hook in pre_update_hooks.values(): + pre_update_hook() yield from super().collect() - def register_hook(self, metric_name: str, hook: Callable[[], None]) -> None: + def register_hook( + self, server_name: str, metric_name: str, hook: Callable[[], None] + ) -> None: """ Registers a hook that is called before metric collection. """ - self._pre_update_hooks[metric_name] = hook + server_hooks = self._server_name_to_pre_update_hooks.setdefault(server_name, {}) + if server_hooks.get(metric_name) is not None: + # TODO: This should be an `assert` since registering the same metric name + # multiple times will clobber the old metric. + # We currently rely on this behaviour as we instantiate multiple + # `SyncRestServlet`, one per listener, and in the `__init__` we setup a new + # LruCache. + # Once the above behaviour is changed, this should be changed to an `assert`. + logger.error( + "Metric named %s already registered for server %s", + metric_name, + server_name, + ) + server_hooks[metric_name] = hook + + def unregister_hooks_for_homeserver(self, server_name: str) -> None: + self._server_name_to_pre_update_hooks.pop(server_name, None) diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 695eb462bf..756677fe6c 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -419,4 +419,7 @@ class _PerHostRatelimiter: except KeyError: pass - self.clock.call_later(0.0, start_next_request) + self.clock.call_later( + 0.0, + start_next_request, + ) diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 42a0cc7aa8..96fe2bd566 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -24,7 +24,6 @@ from types import TracebackType from typing import TYPE_CHECKING, Any, Optional, Type from synapse.api.errors import CodeMessageException -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage import DataStore from synapse.types import StrCollection from synapse.util.clock import Clock @@ -32,6 +31,7 @@ from synapse.util.clock import Clock if TYPE_CHECKING: from synapse.notifier import Notifier from synapse.replication.tcp.handler import ReplicationCommandHandler + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -62,6 +62,7 @@ async def get_retry_limiter( *, destination: str, our_server_name: str, + hs: "HomeServer", clock: Clock, store: DataStore, ignore_backoff: bool = False, @@ -124,6 +125,7 @@ async def get_retry_limiter( return RetryDestinationLimiter( destination=destination, our_server_name=our_server_name, + hs=hs, clock=clock, store=store, failure_ts=failure_ts, @@ -163,6 +165,7 @@ class RetryDestinationLimiter: *, destination: str, our_server_name: str, + hs: "HomeServer", clock: Clock, store: DataStore, failure_ts: Optional[int], @@ -181,6 +184,7 @@ class RetryDestinationLimiter: Args: destination our_server_name: Our homeserver name (used to label metrics) (`hs.hostname`) + hs: The homeserver instance clock store failure_ts: when this destination started failing (in ms since @@ -197,6 +201,7 @@ class RetryDestinationLimiter: error code. """ self.our_server_name = our_server_name + self.hs = hs self.clock = clock self.store = store self.destination = destination @@ -331,6 +336,4 @@ class RetryDestinationLimiter: logger.exception("Failed to store destination_retry_timings") # we deliberately do this in the background. - run_as_background_process( - "store_retry_timings", self.our_server_name, store_retry_timings - ) + self.hs.run_as_background_process("store_retry_timings", store_retry_timings) diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py index 0539989320..7443d4e097 100644 --- a/synapse/util/task_scheduler.py +++ b/synapse/util/task_scheduler.py @@ -32,7 +32,6 @@ from synapse.logging.context import ( ) from synapse.metrics import SERVER_NAME_LABEL, LaterGauge from synapse.metrics.background_process_metrics import ( - run_as_background_process, wrap_as_background_process, ) from synapse.types import JsonMapping, ScheduledTask, TaskStatus @@ -107,10 +106,8 @@ class TaskScheduler: OCCASIONAL_REPORT_INTERVAL_MS = 5 * 60 * 1000 # 5 minutes def __init__(self, hs: "HomeServer"): - self._hs = hs - self.server_name = ( - hs.hostname - ) # nb must be called this for @wrap_as_background_process + self.hs = hs # nb must be called this for @wrap_as_background_process + self.server_name = hs.hostname self._store = hs.get_datastores().main self._clock = hs.get_clock() self._running_tasks: Set[str] = set() @@ -215,7 +212,7 @@ class TaskScheduler: if self._run_background_tasks: self._launch_scheduled_tasks() else: - self._hs.get_replication_command_handler().send_new_active_task(task.id) + self.hs.get_replication_command_handler().send_new_active_task(task.id) return task.id @@ -362,7 +359,7 @@ class TaskScheduler: finally: self._launching_new_tasks = False - run_as_background_process("launch_scheduled_tasks", self.server_name, inner) + self.hs.run_as_background_process("launch_scheduled_tasks", inner) @wrap_as_background_process("clean_scheduled_tasks") async def _clean_scheduled_tasks(self) -> None: @@ -473,7 +470,10 @@ class TaskScheduler: occasional_status_call.stop() # Try launch a new task since we've finished with this one. - self._clock.call_later(0.1, self._launch_scheduled_tasks) + self._clock.call_later( + 0.1, + self._launch_scheduled_tasks, + ) if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: return @@ -493,4 +493,4 @@ class TaskScheduler: self._running_tasks.add(task.id) await self.update_task(task.id, status=TaskStatus.ACTIVE) - run_as_background_process(f"task-{task.action}", self.server_name, wrapper) + self.hs.run_as_background_process(f"task-{task.action}", wrapper) diff --git a/synmark/suites/logging.py b/synmark/suites/logging.py index c3f3cceaa6..cf9c836e06 100644 --- a/synmark/suites/logging.py +++ b/synmark/suites/logging.py @@ -86,7 +86,9 @@ async def main(reactor: ISynapseReactor, loops: int) -> float: hs_config = Config() # To be able to sleep. - clock = Clock(reactor, server_name=hs_config.server.server_name) + # Ignore linter error here since we are running outside of the context of a + # Synapse `HomeServer`. + clock = Clock(reactor, server_name=hs_config.server.server_name) # type: ignore[multiple-internal-clocks] errors = StringIO() publisher = LogPublisher() diff --git a/synmark/suites/lrucache.py b/synmark/suites/lrucache.py index 6314035bd7..830a3daa8f 100644 --- a/synmark/suites/lrucache.py +++ b/synmark/suites/lrucache.py @@ -23,14 +23,19 @@ from pyperf import perf_counter from synapse.types import ISynapseReactor from synapse.util.caches.lrucache import LruCache +from synapse.util.clock import Clock async def main(reactor: ISynapseReactor, loops: int) -> float: """ Benchmark `loops` number of insertions into LruCache without eviction. """ + # Ignore linter error here since we are running outside of the context of a + # Synapse `HomeServer`. cache: LruCache[int, bool] = LruCache( - max_size=loops, server_name="synmark_benchmark" + max_size=loops, + clock=Clock(reactor, server_name="synmark_benchmark"), # type: ignore[multiple-internal-clocks] + server_name="synmark_benchmark", ) start = perf_counter() diff --git a/synmark/suites/lrucache_evict.py b/synmark/suites/lrucache_evict.py index b8cd589697..c67e0c9001 100644 --- a/synmark/suites/lrucache_evict.py +++ b/synmark/suites/lrucache_evict.py @@ -23,6 +23,7 @@ from pyperf import perf_counter from synapse.types import ISynapseReactor from synapse.util.caches.lrucache import LruCache +from synapse.util.clock import Clock async def main(reactor: ISynapseReactor, loops: int) -> float: @@ -30,8 +31,12 @@ async def main(reactor: ISynapseReactor, loops: int) -> float: Benchmark `loops` number of insertions into LruCache where half of them are evicted. """ + # Ignore linter error here since we are running outside of the context of a + # Synapse `HomeServer`. cache: LruCache[int, bool] = LruCache( - max_size=loops // 2, server_name="synmark_benchmark" + max_size=loops // 2, + clock=Clock(reactor, server_name="synmark_benchmark"), # type: ignore[multiple-internal-clocks] + server_name="synmark_benchmark", ) start = perf_counter() diff --git a/tests/app/test_homeserver_shutdown.py b/tests/app/test_homeserver_shutdown.py new file mode 100644 index 0000000000..d8119ba310 --- /dev/null +++ b/tests/app/test_homeserver_shutdown.py @@ -0,0 +1,193 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +# Originally licensed under the Apache License, Version 2.0: +# . +# +# [This file includes modifications made by New Vector Limited] +# +# + +import gc +import weakref + +from synapse.app.homeserver import SynapseHomeServer +from synapse.storage.background_updates import UpdaterStatus + +from tests.server import ( + cleanup_test_reactor_system_event_triggers, + get_clock, + setup_test_homeserver, +) +from tests.unittest import HomeserverTestCase + + +class HomeserverCleanShutdownTestCase(HomeserverTestCase): + def setUp(self) -> None: + pass + + # NOTE: ideally we'd have another test to ensure we properly shutdown with + # real in-flight HTTP requests since those result in additional resources being + # setup that hold strong references to the homeserver. + # Mainly, the HTTP channel created by a real TCP connection from client to server + # is held open between requests and care needs to be taken in Twisted to ensure it is properly + # closed in a timely manner during shutdown. Simulating this behaviour in a unit test + # won't be as good as a proper integration test in complement. + + def test_clean_homeserver_shutdown(self) -> None: + """Ensure the `SynapseHomeServer` can be fully shutdown and garbage collected""" + self.reactor, self.clock = get_clock() + self.hs = setup_test_homeserver( + cleanup_func=self.addCleanup, + reactor=self.reactor, + homeserver_to_use=SynapseHomeServer, + clock=self.clock, + ) + self.wait_for_background_updates() + + hs_ref = weakref.ref(self.hs) + + # Run the reactor so any `callWhenRunning` functions can be cleared out. + self.reactor.run() + # This would normally happen as part of `HomeServer.shutdown` but the `MemoryReactor` + # we use in tests doesn't handle this properly (see doc comment) + cleanup_test_reactor_system_event_triggers(self.reactor) + + # Cleanup the homeserver. + self.get_success(self.hs.shutdown()) + + # Cleanup the internal reference in our test case + del self.hs + + # Force garbage collection. + gc.collect() + + # Ensure the `HomeServer` hs been garbage collected by attempting to use the + # weakref to it. + if hs_ref() is not None: + self.fail("HomeServer reference should not be valid at this point") + + # To help debug this test when it fails, it is useful to leverage the + # `objgraph` module. + # The following code serves as an example of what I have found to be useful + # when tracking down references holding the `SynapseHomeServer` in memory: + # + # all_objects = gc.get_objects() + # for obj in all_objects: + # try: + # # These are a subset of types that are typically involved with + # # holding the `HomeServer` in memory. You may want to inspect + # # other types as well. + # if isinstance(obj, DataStore): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # db_obj = obj + # if isinstance(obj, SynapseHomeServer): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # synapse_hs = obj + # if isinstance(obj, SynapseSite): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # sysite = obj + # if isinstance(obj, DatabasePool): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # dbpool = obj + # except Exception: + # pass + # + # print(sys.getrefcount(hs_ref()), "refs to", hs_ref()) + # + # # The following values for `max_depth` and `too_many` have been found to + # # render a useful amount of information without taking an overly long time + # # to generate the result. + # objgraph.show_backrefs(synapse_hs, max_depth=10, too_many=10) + + def test_clean_homeserver_shutdown_mid_background_updates(self) -> None: + """Ensure the `SynapseHomeServer` can be fully shutdown and garbage collected + before background updates have completed""" + self.reactor, self.clock = get_clock() + self.hs = setup_test_homeserver( + cleanup_func=self.addCleanup, + reactor=self.reactor, + homeserver_to_use=SynapseHomeServer, + clock=self.clock, + ) + + # Pump the background updates by a single iteration, just to ensure any extra + # resources it uses have been started. + store = weakref.proxy(self.hs.get_datastores().main) + self.get_success(store.db_pool.updates.do_next_background_update(False), by=0.1) + + hs_ref = weakref.ref(self.hs) + + # Run the reactor so any `callWhenRunning` functions can be cleared out. + self.reactor.run() + # This would normally happen as part of `HomeServer.shutdown` but the `MemoryReactor` + # we use in tests doesn't handle this properly (see doc comment) + cleanup_test_reactor_system_event_triggers(self.reactor) + + # Ensure the background updates are not complete. + self.assertNotEqual(store.db_pool.updates.get_status(), UpdaterStatus.COMPLETE) + + # Cleanup the homeserver. + self.get_success(self.hs.shutdown()) + + # Cleanup the internal reference in our test case + del self.hs + + # Force garbage collection. + gc.collect() + + # Ensure the `HomeServer` hs been garbage collected by attempting to use the + # weakref to it. + if hs_ref() is not None: + self.fail("HomeServer reference should not be valid at this point") + + # To help debug this test when it fails, it is useful to leverage the + # `objgraph` module. + # The following code serves as an example of what I have found to be useful + # when tracking down references holding the `SynapseHomeServer` in memory: + # + # all_objects = gc.get_objects() + # for obj in all_objects: + # try: + # # These are a subset of types that are typically involved with + # # holding the `HomeServer` in memory. You may want to inspect + # # other types as well. + # if isinstance(obj, DataStore): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # db_obj = obj + # if isinstance(obj, SynapseHomeServer): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # synapse_hs = obj + # if isinstance(obj, SynapseSite): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # sysite = obj + # if isinstance(obj, DatabasePool): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # dbpool = obj + # except Exception: + # pass + # + # print(sys.getrefcount(hs_ref()), "refs to", hs_ref()) + # + # # The following values for `max_depth` and `too_many` have been found to + # # render a useful amount of information without taking an overly long time + # # to generate the result. + # objgraph.show_backrefs(synapse_hs, max_depth=10, too_many=10) diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 0385190f34..f4490a1a79 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -167,8 +167,9 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): ) -class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): +class ApplicationServiceSchedulerRecovererTestCase(unittest.HomeserverTestCase): def setUp(self) -> None: + super().setUp() self.reactor, self.clock = get_clock() self.as_api = Mock() self.store = Mock() @@ -176,6 +177,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.callback = AsyncMock() self.recoverer = _Recoverer( server_name="test_server", + hs=self.hs, clock=self.clock, as_api=self.as_api, store=self.store, diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py index f56d6044a9..74db2dab08 100644 --- a/tests/config/test_cache.py +++ b/tests/config/test_cache.py @@ -24,6 +24,7 @@ from synapse.config.cache import CacheConfig, add_resizable_cache from synapse.types import JsonDict from synapse.util.caches.lrucache import LruCache +from tests.server import get_clock from tests.unittest import TestCase @@ -32,6 +33,7 @@ class CacheConfigTests(TestCase): # Reset caches before each test since there's global state involved. self.config = CacheConfig(RootConfig()) self.config.reset() + _, self.clock = get_clock() def tearDown(self) -> None: # Also reset the caches after each test to leave state pristine. @@ -75,7 +77,9 @@ class CacheConfigTests(TestCase): the default cache size in the interim, and then resized once the config is loaded. """ - cache: LruCache = LruCache(max_size=100, server_name="test_server") + cache: LruCache = LruCache( + max_size=100, clock=self.clock, server_name="test_server" + ) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 50) @@ -96,7 +100,9 @@ class CacheConfigTests(TestCase): self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.resize_all_caches() - cache: LruCache = LruCache(max_size=100, server_name="test_server") + cache: LruCache = LruCache( + max_size=100, clock=self.clock, server_name="test_server" + ) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 200) @@ -106,7 +112,9 @@ class CacheConfigTests(TestCase): the default cache size in the interim, and then resized to the new default cache size once the config is loaded. """ - cache: LruCache = LruCache(max_size=100, server_name="test_server") + cache: LruCache = LruCache( + max_size=100, clock=self.clock, server_name="test_server" + ) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 50) @@ -126,7 +134,9 @@ class CacheConfigTests(TestCase): self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.resize_all_caches() - cache: LruCache = LruCache(max_size=100, server_name="test_server") + cache: LruCache = LruCache( + max_size=100, clock=self.clock, server_name="test_server" + ) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 150) @@ -145,15 +155,21 @@ class CacheConfigTests(TestCase): self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.resize_all_caches() - cache_a: LruCache = LruCache(max_size=100, server_name="test_server") + cache_a: LruCache = LruCache( + max_size=100, clock=self.clock, server_name="test_server" + ) add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor) self.assertEqual(cache_a.max_size, 200) - cache_b: LruCache = LruCache(max_size=100, server_name="test_server") + cache_b: LruCache = LruCache( + max_size=100, clock=self.clock, server_name="test_server" + ) add_resizable_cache("*Cache_b*", cache_resize_callback=cache_b.set_cache_factor) self.assertEqual(cache_b.max_size, 300) - cache_c: LruCache = LruCache(max_size=100, server_name="test_server") + cache_c: LruCache = LruCache( + max_size=100, clock=self.clock, server_name="test_server" + ) add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor) self.assertEqual(cache_c.max_size, 200) @@ -168,6 +184,7 @@ class CacheConfigTests(TestCase): cache: LruCache = LruCache( max_size=self.config.event_cache_size, + clock=self.clock, apply_cache_factor_from_config=False, server_name="test_server", ) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 6516b7db17..df36185b99 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -19,7 +19,17 @@ # # -from typing import Dict, Iterable, List, Optional +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Iterable, + List, + Optional, + TypeVar, +) from unittest.mock import AsyncMock, Mock from parameterized import parameterized @@ -36,6 +46,7 @@ from synapse.appservice import ( TransactionUnusedFallbackKeys, ) from synapse.handlers.appservice import ApplicationServicesHandler +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client import login, receipts, register, room, sendtodevice from synapse.server import HomeServer from synapse.types import ( @@ -53,6 +64,11 @@ from tests.server import get_clock from tests.test_utils import event_injection from tests.unittest import override_config +if TYPE_CHECKING: + from typing_extensions import LiteralString + +R = TypeVar("R") + class AppServiceHandlerTestCase(unittest.TestCase): """Tests the ApplicationServicesHandler.""" @@ -64,6 +80,17 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.reactor, self.clock = get_clock() hs = Mock() + + def test_run_as_background_process( + desc: "LiteralString", + func: Callable[..., Awaitable[Optional[R]]], + *args: Any, + **kwargs: Any, + ) -> "defer.Deferred[Optional[R]]": + # Ignore linter error as this is used only for testing purposes (i.e. outside of Synapse). + return run_as_background_process(desc, "test_server", func, *args, **kwargs) # type: ignore[untracked-background-process] + + hs.run_as_background_process = test_run_as_background_process hs.get_datastores.return_value = Mock(main=self.mock_store) self.mock_store.get_appservice_last_pos = AsyncMock(return_value=None) self.mock_store.set_appservice_last_pos = AsyncMock(return_value=None) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 4d2807151e..90c185bc3d 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -79,15 +79,17 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) -> HomeServer: # we mock out the keyring so as to skip the authentication check on the # federation API call. - mock_keyring = Mock(spec=["verify_json_for_server"]) + mock_keyring = Mock(spec=["verify_json_for_server", "shutdown"]) mock_keyring.verify_json_for_server = AsyncMock(return_value=True) + mock_keyring.shutdown = Mock() # we mock out the federation client too self.mock_federation_client = AsyncMock(spec=["put_json"]) self.mock_federation_client.put_json.return_value = (200, "OK") self.mock_federation_client.agent = MatrixFederationAgent( server_name="OUR_STUB_HOMESERVER_NAME", - reactor=reactor, + reactor=self.reactor, + clock=self.clock, tls_client_options_factory=None, user_agent=b"SynapseInTrialTest/0.0.0", ip_allowlist=None, @@ -96,7 +98,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) # the tests assume that we are starting at unix time 1000 - reactor.pump((1000,)) + self.reactor.pump((1000,)) self.mock_hs_notifier = Mock() hs = self.setup_test_homeserver( diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index a5e1b7c284..c66ca489a4 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -65,7 +65,7 @@ from synapse.util.caches.ttlcache import TTLCache from tests import unittest from tests.http import dummy_address, get_test_ca_cert_file, wrap_server_factory_for_tls -from tests.server import FakeTransport, ThreadedMemoryReactorClock +from tests.server import FakeTransport, get_clock from tests.utils import checked_cast, default_config logger = logging.getLogger(__name__) @@ -73,7 +73,7 @@ logger = logging.getLogger(__name__) class MatrixFederationAgentTests(unittest.TestCase): def setUp(self) -> None: - self.reactor = ThreadedMemoryReactorClock() + self.reactor, self.clock = get_clock() self.mock_resolver = AsyncMock(spec=SrvResolver) @@ -98,6 +98,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.well_known_resolver = WellKnownResolver( server_name="OUR_STUB_HOMESERVER_NAME", reactor=self.reactor, + clock=self.clock, agent=Agent(self.reactor, contextFactory=self.tls_factory), user_agent=b"test-agent", well_known_cache=self.well_known_cache, @@ -280,6 +281,7 @@ class MatrixFederationAgentTests(unittest.TestCase): return MatrixFederationAgent( server_name="OUR_STUB_HOMESERVER_NAME", reactor=cast(ISynapseReactor, self.reactor), + clock=self.clock, tls_client_options_factory=self.tls_factory, user_agent=b"test-agent", # Note that this is unused since _well_known_resolver is provided. ip_allowlist=IPSet(), @@ -1024,6 +1026,7 @@ class MatrixFederationAgentTests(unittest.TestCase): agent = MatrixFederationAgent( server_name="OUR_STUB_HOMESERVER_NAME", reactor=self.reactor, + clock=self.clock, tls_client_options_factory=tls_factory, user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below. ip_allowlist=IPSet(), @@ -1033,6 +1036,7 @@ class MatrixFederationAgentTests(unittest.TestCase): _well_known_resolver=WellKnownResolver( server_name="OUR_STUB_HOMESERVER_NAME", reactor=cast(ISynapseReactor, self.reactor), + clock=self.clock, agent=Agent(self.reactor, contextFactory=tls_factory), user_agent=b"test-agent", well_known_cache=self.well_known_cache, diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py index 057ca0db45..31cdfacd2c 100644 --- a/tests/logging/test_opentracing.py +++ b/tests/logging/test_opentracing.py @@ -163,7 +163,9 @@ class TracingScopeTestCase(TestCase): # implements `ISynapseThreadlessReactor` (combination of the normal Twisted # Reactor/Clock interfaces), via inheritance from # `twisted.internet.testing.MemoryReactor` and `twisted.internet.testing.Clock` - clock = Clock( + # Ignore `multiple-internal-clocks` linter error here since we are creating a `Clock` + # for testing purposes. + clock = Clock( # type: ignore[multiple-internal-clocks] reactor, # type: ignore[arg-type] server_name="test_server", ) @@ -234,7 +236,9 @@ class TracingScopeTestCase(TestCase): # implements `ISynapseThreadlessReactor` (combination of the normal Twisted # Reactor/Clock interfaces), via inheritance from # `twisted.internet.testing.MemoryReactor` and `twisted.internet.testing.Clock` - clock = Clock( + # Ignore `multiple-internal-clocks` linter error here since we are creating a `Clock` + # for testing purposes. + clock = Clock( # type: ignore[multiple-internal-clocks] reactor, # type: ignore[arg-type] server_name="test_server", ) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 832e991730..b3f42c76f1 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -164,7 +164,10 @@ class CacheMetricsTests(unittest.HomeserverTestCase): """ CACHE_NAME = "cache_metrics_test_fgjkbdfg" cache: DeferredCache[str, str] = DeferredCache( - name=CACHE_NAME, server_name=self.hs.hostname, max_entries=777 + name=CACHE_NAME, + clock=self.hs.get_clock(), + server_name=self.hs.hostname, + max_entries=777, ) metrics_map = get_latest_metrics() @@ -212,10 +215,10 @@ class CacheMetricsTests(unittest.HomeserverTestCase): """ CACHE_NAME = "cache_metric_multiple_servers_test" cache1: DeferredCache[str, str] = DeferredCache( - name=CACHE_NAME, server_name="hs1", max_entries=777 + name=CACHE_NAME, clock=self.clock, server_name="hs1", max_entries=777 ) cache2: DeferredCache[str, str] = DeferredCache( - name=CACHE_NAME, server_name="hs2", max_entries=777 + name=CACHE_NAME, clock=self.clock, server_name="hs2", max_entries=777 ) metrics_map = get_latest_metrics() diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 36d3213908..1a2dab4c7d 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -173,7 +173,13 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # Set up the server side protocol server_address = IPv4Address("TCP", host, port) - channel = self.site.buildProtocol((host, port)) + # The type ignore is here because mypy doesn't think the host/port tuple is of + # the correct type, even though it is the exact example given for + # `twisted.internet.interfaces.IAddress`. + # Mypy was happy with the type before we overrode `buildProtocol` in + # `SynapseSite`, probably because there was enough inheritance indirection before + # withe the argument not having a type associated with it. + channel = self.site.buildProtocol((host, port)) # type: ignore[arg-type] # hook into the channel's request factory so that we can keep a record # of the requests @@ -185,7 +191,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): requests.append(request) return request - channel.requestFactory = request_factory + channel.requestFactory = request_factory # type: ignore[method-assign] # Connect client to server and vice versa. client_to_server_transport = FakeTransport( @@ -427,7 +433,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # Set up the server side protocol server_address = IPv4Address("TCP", host, port) - channel = self._hs_to_site[hs].buildProtocol((host, port)) + channel = self._hs_to_site[hs].buildProtocol((host, port)) # type: ignore[arg-type] # Connect client to server and vice versa. client_to_server_transport = FakeTransport( diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 92259f2542..3896e0ce8a 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -66,10 +66,11 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): def setUp(self) -> None: super().setUp() - reactor, _ = get_clock() + reactor, clock = get_clock() self.matrix_federation_agent = MatrixFederationAgent( server_name="OUR_STUB_HOMESERVER_NAME", reactor=reactor, + clock=clock, tls_client_options_factory=None, user_agent=b"SynapseInTrialTest/0.0.0", ip_allowlist=None, diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py index 8d5d0cce9a..1cb898673b 100644 --- a/tests/replication/test_module_cache_invalidation.py +++ b/tests/replication/test_module_cache_invalidation.py @@ -24,6 +24,7 @@ import synapse from synapse.module_api import cached from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.server import get_clock logger = logging.getLogger(__name__) @@ -36,6 +37,7 @@ KEY = "mykey" class TestCache: current_value = FIRST_VALUE server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() async def cached_function(self, user_id: str) -> str: diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index c22c1a6612..bb83988d76 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -93,8 +93,10 @@ class HttpTransactionCacheTestCase(unittest.TestCase): ) -> Generator["defer.Deferred[Any]", object, None]: @defer.inlineCallbacks def cb() -> Generator["defer.Deferred[object]", object, Tuple[int, JsonDict]]: + # Ignore `multiple-internal-clocks` linter error here since we are creating a `Clock` + # for testing purposes. yield defer.ensureDeferred( - Clock(reactor, server_name="test_server").sleep(0) + Clock(reactor, server_name="test_server").sleep(0) # type: ignore[multiple-internal-clocks] ) return 1, {} diff --git a/tests/server.py b/tests/server.py index 226bdf4bbe..a9a53eb8a4 100644 --- a/tests/server.py +++ b/tests/server.py @@ -28,6 +28,7 @@ import sqlite3 import time import uuid import warnings +import weakref from collections import deque from io import SEEK_END, BytesIO from typing import ( @@ -56,7 +57,7 @@ from zope.interface import implementer import twisted from twisted.enterprise import adbapi -from twisted.internet import address, tcp, threads, udp +from twisted.internet import address, defer, tcp, threads, udp from twisted.internet._resolver import SimpleResolverComplexifier from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed @@ -524,6 +525,19 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): # overwrite it again. self.nameResolver = SimpleResolverComplexifier(FakeResolver()) + def run(self) -> None: + """ + Override the call from `MemoryReactorClock` to add an additional step that + cleans up any `whenRunningHooks` that have been called. + This is necessary for a clean shutdown to occur as these hooks can hold + references to the `SynapseHomeServer`. + """ + super().run() + + # `MemoryReactorClock` never clears the hooks that have already been called. + # So manually clear the hooks here after they have been run. + self.whenRunningHooks.clear() + def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver: raise NotImplementedError() @@ -649,6 +663,19 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): super().advance(0) +def cleanup_test_reactor_system_event_triggers( + reactor: ThreadedMemoryReactorClock, +) -> None: + """Cleanup any registered system event triggers. + The `twisted.internet.test.ThreadedMemoryReactor` does not implement + `removeSystemEventTrigger` so won't clean these triggers up on it's own properly. + When trying to override `removeSystemEventTrigger` in `ThreadedMemoryReactorClock` + in order to implement this functionality, twisted complains about the reactor being + unclean and fails some tests. + """ + reactor.triggers.clear() + + def validate_connector(connector: tcp.Connector, expected_ip: str) -> None: """Try to validate the obtained connector as it would happen when synapse is running and the conection will be established. @@ -780,13 +807,18 @@ class ThreadPool: d: "Deferred[None]" = Deferred() d.addCallback(lambda x: function(*args, **kwargs)) d.addBoth(_) - self._reactor.callLater(0, d.callback, True) + # mypy ignored here because: + # - this is part of the test infrastructure (outside of Synapse) so tracking + # these calls for for homeserver shutdown doesn't make sense. + self._reactor.callLater(0, d.callback, True) # type: ignore[call-later-not-tracked] return d def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: + # Ignore the linter error since this is an expected usage of creating a `Clock` for + # testing purposes. reactor = ThreadedMemoryReactorClock() - hs_clock = Clock(reactor, server_name="test_server") + hs_clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] return reactor, hs_clock @@ -898,10 +930,16 @@ class FakeTransport: # some implementations of IProducer (for example, FileSender) # don't return a deferred. d = maybeDeferred(self.producer.resumeProducing) - d.addCallback(lambda x: self._reactor.callLater(0.1, _produce)) + # mypy ignored here because: + # - this is part of the test infrastructure (outside of Synapse) so tracking + # these calls for for homeserver shutdown doesn't make sense. + d.addCallback(lambda x: self._reactor.callLater(0.1, _produce)) # type: ignore[call-later-not-tracked,call-overload] if not streaming: - self._reactor.callLater(0.0, _produce) + # mypy ignored here because: + # - this is part of the test infrastructure (outside of Synapse) so tracking + # these calls for for homeserver shutdown doesn't make sense. + self._reactor.callLater(0.0, _produce) # type: ignore[call-later-not-tracked] def write(self, byt: bytes) -> None: if self.disconnecting: @@ -913,7 +951,10 @@ class FakeTransport: # TLSMemoryBIOProtocol) get very confused if a read comes back while they are # still doing a write. Doing a callLater here breaks the cycle. if self.autoflush: - self._reactor.callLater(0.0, self.flush) + # mypy ignored here because: + # - this is part of the test infrastructure (outside of Synapse) so tracking + # these calls for for homeserver shutdown doesn't make sense. + self._reactor.callLater(0.0, self.flush) # type: ignore[call-later-not-tracked] def writeSequence(self, seq: Iterable[bytes]) -> None: for x in seq: @@ -943,7 +984,10 @@ class FakeTransport: self.buffer = self.buffer[len(to_write) :] if self.buffer and self.autoflush: - self._reactor.callLater(0.0, self.flush) + # mypy ignored here because: + # - this is part of the test infrastructure (outside of Synapse) so tracking + # these calls for for homeserver shutdown doesn't make sense. + self._reactor.callLater(0.0, self.flush) # type: ignore[call-later-not-tracked] if not self.buffer and self.disconnecting: logger.info("FakeTransport: Buffer now empty, completing disconnect") @@ -1020,7 +1064,7 @@ class TestHomeServer(HomeServer): def setup_test_homeserver( *, - cleanup_func: Callable[[Callable[[], None]], None], + cleanup_func: Callable[[Callable[[], Optional["Deferred[None]"]]], None], server_name: str = "test", config: Optional[HomeServerConfig] = None, reactor: Optional[ISynapseReactor] = None, @@ -1035,8 +1079,10 @@ def setup_test_homeserver( If no datastore is supplied, one is created and given to the homeserver. Args: - cleanup_func: The function used to register a cleanup routine for after the - test. + cleanup_func : The function used to register a cleanup routine for + after the test. If the function returns a Deferred, the + test case will wait until the Deferred has fired before + proceeding to the next cleanup function. server_name: Homeserver name config: Homeserver config reactor: Twisted reactor @@ -1062,7 +1108,9 @@ def setup_test_homeserver( raise ConfigError("Must be a string", ("server_name",)) if "clock" not in extra_homeserver_attributes: - extra_homeserver_attributes["clock"] = Clock(reactor, server_name=server_name) + # Ignore `multiple-internal-clocks` linter error here since we are creating a `Clock` + # for testing purposes (i.e. outside of Synapse). + extra_homeserver_attributes["clock"] = Clock(reactor, server_name=server_name) # type: ignore[multiple-internal-clocks] config.caches.resize_all_caches() @@ -1154,8 +1202,21 @@ def setup_test_homeserver( reactor=reactor, ) - # Register the cleanup hook - cleanup_func(hs.cleanup) + # Capture the `hs` as a `weakref` here to ensure there is no scenario where uncalled + # cleanup functions result in holding the `hs` in memory. + cleanup_hs_ref = weakref.ref(hs) + + def shutdown_hs_on_cleanup() -> "Deferred[None]": + cleanup_hs = cleanup_hs_ref() + deferred: "Deferred[None]" = defer.succeed(None) + if cleanup_hs is not None: + deferred = defer.ensureDeferred(cleanup_hs.shutdown()) + return deferred + + # Register the cleanup hook for the homeserver. + # A full `hs.shutdown()` is necessary otherwise CI tests will fail while exhibiting + # strange behaviours. + cleanup_func(shutdown_hs_on_cleanup) # Install @cache_in_self attributes for key, val in extra_homeserver_attributes.items(): @@ -1184,14 +1245,18 @@ def setup_test_homeserver( hs.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False if USE_POSTGRES_FOR_TESTS: - database_pool = hs.get_datastores().databases[0] + # Capture the `database_pool` as a `weakref` here to ensure there is no scenario where uncalled + # cleanup functions result in holding the `hs` in memory. + database_pool = weakref.ref(hs.get_datastores().databases[0]) # We need to do cleanup on PostgreSQL def cleanup() -> None: import psycopg2 # Close all the db pools - database_pool._db_pool.close() + db_pool = database_pool() + if db_pool is not None: + db_pool._db_pool.close() dropped = False diff --git a/tests/test_distributor.py b/tests/test_distributor.py index 19dafe64ed..2dd26833c8 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -26,9 +26,10 @@ from synapse.util.distributor import Distributor from . import unittest -class DistributorTestCase(unittest.TestCase): +class DistributorTestCase(unittest.HomeserverTestCase): def setUp(self) -> None: - self.dist = Distributor(server_name="test_server") + super().setUp() + self.dist = Distributor(hs=self.hs) def test_signal_dispatch(self) -> None: self.dist.declare("alert") diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index 7017d6d70a..f0deb1554e 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -26,20 +26,26 @@ from twisted.internet import defer from synapse.util.caches.deferred_cache import DeferredCache +from tests.server import get_clock from tests.unittest import TestCase class DeferredCacheTestCase(TestCase): + def setUp(self) -> None: + super().setUp() + + _, self.clock = get_clock() + def test_empty(self) -> None: cache: DeferredCache[str, int] = DeferredCache( - name="test", server_name="test_server" + name="test", clock=self.clock, server_name="test_server" ) with self.assertRaises(KeyError): cache.get("foo") def test_hit(self) -> None: cache: DeferredCache[str, int] = DeferredCache( - name="test", server_name="test_server" + name="test", clock=self.clock, server_name="test_server" ) cache.prefill("foo", 123) @@ -47,7 +53,7 @@ class DeferredCacheTestCase(TestCase): def test_hit_deferred(self) -> None: cache: DeferredCache[str, int] = DeferredCache( - name="test", server_name="test_server" + name="test", clock=self.clock, server_name="test_server" ) origin_d: "defer.Deferred[int]" = defer.Deferred() set_d = cache.set("k1", origin_d) @@ -72,7 +78,7 @@ class DeferredCacheTestCase(TestCase): def test_callbacks(self) -> None: """Invalidation callbacks are called at the right time""" cache: DeferredCache[str, int] = DeferredCache( - name="test", server_name="test_server" + name="test", clock=self.clock, server_name="test_server" ) callbacks = set() @@ -107,7 +113,7 @@ class DeferredCacheTestCase(TestCase): def test_set_fail(self) -> None: cache: DeferredCache[str, int] = DeferredCache( - name="test", server_name="test_server" + name="test", clock=self.clock, server_name="test_server" ) callbacks = set() @@ -146,7 +152,7 @@ class DeferredCacheTestCase(TestCase): def test_get_immediate(self) -> None: cache: DeferredCache[str, int] = DeferredCache( - name="test", server_name="test_server" + name="test", clock=self.clock, server_name="test_server" ) d1: "defer.Deferred[int]" = defer.Deferred() cache.set("key1", d1) @@ -164,7 +170,7 @@ class DeferredCacheTestCase(TestCase): def test_invalidate(self) -> None: cache: DeferredCache[Tuple[str], int] = DeferredCache( - name="test", server_name="test_server" + name="test", clock=self.clock, server_name="test_server" ) cache.prefill(("foo",), 123) cache.invalidate(("foo",)) @@ -174,7 +180,7 @@ class DeferredCacheTestCase(TestCase): def test_invalidate_all(self) -> None: cache: DeferredCache[str, str] = DeferredCache( - name="testcache", server_name="test_server" + name="testcache", clock=self.clock, server_name="test_server" ) callback_record = [False, False] @@ -220,6 +226,7 @@ class DeferredCacheTestCase(TestCase): def test_eviction(self) -> None: cache: DeferredCache[int, str] = DeferredCache( name="test", + clock=self.clock, server_name="test_server", max_entries=2, apply_cache_factor_from_config=False, @@ -238,6 +245,7 @@ class DeferredCacheTestCase(TestCase): def test_eviction_lru(self) -> None: cache: DeferredCache[int, str] = DeferredCache( name="test", + clock=self.clock, server_name="test_server", max_entries=2, apply_cache_factor_from_config=False, @@ -260,6 +268,7 @@ class DeferredCacheTestCase(TestCase): def test_eviction_iterable(self) -> None: cache: DeferredCache[int, List[str]] = DeferredCache( name="test", + clock=self.clock, server_name="test_server", max_entries=3, apply_cache_factor_from_config=False, diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 3eb502f902..0e3b6ae36b 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -49,6 +49,7 @@ from synapse.util.caches import descriptors from synapse.util.caches.descriptors import _CacheContext, cached, cachedList from tests import unittest +from tests.server import get_clock from tests.test_utils import get_awaitable_result logger = logging.getLogger(__name__) @@ -56,7 +57,10 @@ logger = logging.getLogger(__name__) def run_on_reactor() -> "Deferred[int]": d: "Deferred[int]" = Deferred() - cast(IReactorTime, reactor).callLater(0, d.callback, 0) + # mypy ignored here because: + # - this is part of the test infrastructure (outside of Synapse) so tracking + # these calls for for homeserver shutdown doesn't make sense. + cast(IReactorTime, reactor).callLater(0, d.callback, 0) # type: ignore[call-later-not-tracked] return make_deferred_yieldable(d) @@ -67,6 +71,7 @@ class DescriptorTestCase(unittest.TestCase): def __init__(self) -> None: self.mock = mock.Mock() self.server_name = "test_server" + _, self.clock = get_clock() # nb must be called this for @cached @descriptors.cached() def fn(self, arg1: int, arg2: int) -> str: @@ -102,6 +107,7 @@ class DescriptorTestCase(unittest.TestCase): def __init__(self) -> None: self.mock = mock.Mock() self.server_name = "test_server" + _, self.clock = get_clock() # nb must be called this for @cached @descriptors.cached(num_args=1) def fn(self, arg1: int, arg2: int) -> str: @@ -148,6 +154,7 @@ class DescriptorTestCase(unittest.TestCase): def __init__(self) -> None: self.mock = mock.Mock() self.server_name = "test_server" + _, self.clock = get_clock() # nb must be called this for @cached obj = Cls() obj.mock.return_value = "fish" @@ -179,6 +186,7 @@ class DescriptorTestCase(unittest.TestCase): def __init__(self) -> None: self.mock = mock.Mock() self.server_name = "test_server" + _, self.clock = get_clock() # nb must be called this for @cached @descriptors.cached() def fn(self, arg1: int, kwarg1: int = 2) -> str: @@ -214,6 +222,7 @@ class DescriptorTestCase(unittest.TestCase): class Cls: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def fn(self, arg1: int) -> NoReturn: @@ -239,6 +248,7 @@ class DescriptorTestCase(unittest.TestCase): result: Optional[Deferred] = None call_count = 0 server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def fn(self, arg1: int) -> Deferred: @@ -293,6 +303,7 @@ class DescriptorTestCase(unittest.TestCase): class Cls: server_name = "test_server" + _, clock = get_clock() # nb must be called this for @cached @descriptors.cached() def fn(self, arg1: int) -> "Deferred[int]": @@ -337,6 +348,7 @@ class DescriptorTestCase(unittest.TestCase): class Cls: server_name = "test_server" + _, clock = get_clock() # nb must be called this for @cached @descriptors.cached() def fn(self, arg1: int) -> Deferred: @@ -381,6 +393,7 @@ class DescriptorTestCase(unittest.TestCase): def __init__(self) -> None: self.mock = mock.Mock() self.server_name = "test_server" + _, self.clock = get_clock() # nb must be called this for @cached @descriptors.cached() def fn(self, arg1: int, arg2: int = 2, arg3: int = 3) -> str: @@ -419,6 +432,7 @@ class DescriptorTestCase(unittest.TestCase): def __init__(self) -> None: self.mock = mock.Mock() self.server_name = "test_server" + _, self.clock = get_clock() # nb must be called this for @cached @descriptors.cached(iterable=True) def fn(self, arg1: int, arg2: int) -> Tuple[str, ...]: @@ -453,6 +467,7 @@ class DescriptorTestCase(unittest.TestCase): class Cls: server_name = "test_server" + _, clock = get_clock() # nb must be called this for @cached @descriptors.cached(iterable=True) def fn(self, arg1: int) -> NoReturn: @@ -476,6 +491,7 @@ class DescriptorTestCase(unittest.TestCase): class Cls: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached(cache_context=True) async def func1(self, key: str, cache_context: _CacheContext) -> int: @@ -504,6 +520,7 @@ class DescriptorTestCase(unittest.TestCase): class Cls: server_name = "test_server" + _, clock = get_clock() # nb must be called this for @cached @cached() async def fn(self, arg1: int) -> str: @@ -537,6 +554,7 @@ class DescriptorTestCase(unittest.TestCase): class Cls: inner_context_was_finished = False server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() async def fn(self, arg1: int) -> str: @@ -583,6 +601,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): def test_passthrough(self) -> Generator["Deferred[Any]", object, None]: class A: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def func(self, key: str) -> str: @@ -599,6 +618,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): class A: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def func(self, key: str) -> str: @@ -619,6 +639,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): class A: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def func(self, key: str) -> str: @@ -639,6 +660,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): def test_invalidate_missing(self) -> None: class A: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def func(self, key: str) -> str: @@ -652,6 +674,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): class A: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached(max_entries=10) def func(self, key: int) -> int: @@ -681,6 +704,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): class A: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def func(self, key: str) -> "Deferred[int]": @@ -701,6 +725,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): class A: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def func(self, key: str) -> str: @@ -736,6 +761,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): class A: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached(max_entries=2) def func(self, key: str) -> str: @@ -775,6 +801,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): class A: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def func(self, key: str) -> str: @@ -824,6 +851,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): def __init__(self) -> None: self.mock = mock.Mock() self.server_name = "test_server" + _, self.clock = get_clock() # nb must be called this for @cached @descriptors.cached() def fn(self, arg1: int, arg2: int) -> None: @@ -890,6 +918,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): def __init__(self) -> None: self.mock = mock.Mock() self.server_name = "test_server" + _, self.clock = get_clock() # nb must be called this for @cached @descriptors.cached() def fn(self, arg1: int) -> None: @@ -934,6 +963,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): def __init__(self) -> None: self.mock = mock.Mock() self.server_name = "test_server" + _, self.clock = get_clock() # nb must be called this for @cached @descriptors.cached() def fn(self, arg1: int, arg2: int) -> None: @@ -975,6 +1005,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): class Cls: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def fn(self, arg1: int) -> None: @@ -1011,6 +1042,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): class Cls: inner_context_was_finished = False server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def fn(self, arg1: int) -> None: @@ -1055,6 +1087,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): class Cls: server_name = "test_server" + _, clock = get_clock() # nb must be called this for @cached @descriptors.cached(tree=True) def fn(self, room_id: str, event_id: str) -> None: diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index 54f7b55511..fd8d576aea 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -25,7 +25,6 @@ from parameterized import parameterized_class from twisted.internet import defer from twisted.internet.defer import CancelledError, Deferred, ensureDeferred -from twisted.internet.task import Clock from twisted.python.failure import Failure from synapse.logging.context import ( @@ -152,7 +151,7 @@ class ObservableDeferredTest(TestCase): class TimeoutDeferredTest(TestCase): def setUp(self) -> None: - self.clock = Clock() + self.reactor, self.clock = get_clock() def test_times_out(self) -> None: """Basic test case that checks that the original deferred is cancelled and that @@ -165,12 +164,16 @@ class TimeoutDeferredTest(TestCase): cancelled = True non_completing_d: Deferred = Deferred(canceller) - timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) + timing_out_d = timeout_deferred( + deferred=non_completing_d, + timeout=1.0, + clock=self.clock, + ) self.assertNoResult(timing_out_d) self.assertFalse(cancelled, "deferred was cancelled prematurely") - self.clock.pump((1.0,)) + self.reactor.pump((1.0,)) self.assertTrue(cancelled, "deferred was not cancelled by timeout") self.failureResultOf(timing_out_d, defer.TimeoutError) @@ -183,11 +186,15 @@ class TimeoutDeferredTest(TestCase): raise Exception("can't cancel this deferred") non_completing_d: Deferred = Deferred(canceller) - timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) + timing_out_d = timeout_deferred( + deferred=non_completing_d, + timeout=1.0, + clock=self.clock, + ) self.assertNoResult(timing_out_d) - self.clock.pump((1.0,)) + self.reactor.pump((1.0,)) self.failureResultOf(timing_out_d, defer.TimeoutError) @@ -227,7 +234,7 @@ class TimeoutDeferredTest(TestCase): timing_out_d = timeout_deferred( deferred=incomplete_d, timeout=1.0, - reactor=self.clock, + clock=self.clock, ) self.assertNoResult(timing_out_d) # We should still be in the logcontext we started in @@ -243,7 +250,7 @@ class TimeoutDeferredTest(TestCase): # we're pumping the reactor in the block and return us back to our current # logcontext after the block. with PreserveLoggingContext(): - self.clock.pump( + self.reactor.pump( # We only need to pump `1.0` (seconds) as we set # `timeout_deferred(timeout=1.0)` above (1.0,) @@ -264,7 +271,7 @@ class TimeoutDeferredTest(TestCase): self.assertEqual(current_context(), SENTINEL_CONTEXT) -class _TestException(Exception): +class _TestException(Exception): # pass @@ -560,8 +567,8 @@ class AwakenableSleeperTests(TestCase): "Tests AwakenableSleeper" def test_sleep(self) -> None: - reactor, _ = get_clock() - sleeper = AwakenableSleeper(reactor) + reactor, clock = get_clock() + sleeper = AwakenableSleeper(clock) d = defer.ensureDeferred(sleeper.sleep("name", 1000)) @@ -575,8 +582,8 @@ class AwakenableSleeperTests(TestCase): self.assertTrue(d.called) def test_explicit_wake(self) -> None: - reactor, _ = get_clock() - sleeper = AwakenableSleeper(reactor) + reactor, clock = get_clock() + sleeper = AwakenableSleeper(clock) d = defer.ensureDeferred(sleeper.sleep("name", 1000)) @@ -592,8 +599,8 @@ class AwakenableSleeperTests(TestCase): reactor.advance(0.6) def test_multiple_sleepers_timeout(self) -> None: - reactor, _ = get_clock() - sleeper = AwakenableSleeper(reactor) + reactor, clock = get_clock() + sleeper = AwakenableSleeper(clock) d1 = defer.ensureDeferred(sleeper.sleep("name", 1000)) @@ -612,8 +619,8 @@ class AwakenableSleeperTests(TestCase): self.assertTrue(d2.called) def test_multiple_sleepers_wake(self) -> None: - reactor, _ = get_clock() - sleeper = AwakenableSleeper(reactor) + reactor, clock = get_clock() + sleeper = AwakenableSleeper(clock) d1 = defer.ensureDeferred(sleeper.sleep("name", 1000)) diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py index 532582cf87..60bfdf38aa 100644 --- a/tests/util/test_batching_queue.py +++ b/tests/util/test_batching_queue.py @@ -32,13 +32,12 @@ from synapse.util.batching_queue import ( number_queued, ) -from tests.server import get_clock -from tests.unittest import TestCase +from tests.unittest import HomeserverTestCase -class BatchingQueueTestCase(TestCase): +class BatchingQueueTestCase(HomeserverTestCase): def setUp(self) -> None: - self.clock, hs_clock = get_clock() + super().setUp() # We ensure that we remove any existing metrics for "test_queue". try: @@ -51,8 +50,8 @@ class BatchingQueueTestCase(TestCase): self._pending_calls: List[Tuple[List[str], defer.Deferred]] = [] self.queue: BatchingQueue[str, str] = BatchingQueue( name="test_queue", - server_name="test_server", - clock=hs_clock, + hs=self.hs, + clock=self.clock, process_batch_callback=self._process_queue, ) @@ -108,7 +107,7 @@ class BatchingQueueTestCase(TestCase): self.assertFalse(queue_d.called) # We should see a call to `_process_queue` after a reactor tick. - self.clock.pump([0]) + self.reactor.pump([0]) self.assertEqual(len(self._pending_calls), 1) self.assertEqual(self._pending_calls[0][0], ["foo"]) @@ -134,7 +133,7 @@ class BatchingQueueTestCase(TestCase): self._assert_metrics(queued=2, keys=1, in_flight=2) - self.clock.pump([0]) + self.reactor.pump([0]) # We should see only *one* call to `_process_queue` self.assertEqual(len(self._pending_calls), 1) @@ -158,7 +157,7 @@ class BatchingQueueTestCase(TestCase): self.assertFalse(self._pending_calls) queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1")) - self.clock.pump([0]) + self.reactor.pump([0]) self.assertEqual(len(self._pending_calls), 1) @@ -185,7 +184,7 @@ class BatchingQueueTestCase(TestCase): self._assert_metrics(queued=2, keys=1, in_flight=2) # We should now see a second call to `_process_queue` - self.clock.pump([0]) + self.reactor.pump([0]) self.assertEqual(len(self._pending_calls), 1) self.assertEqual(self._pending_calls[0][0], ["foo2", "foo3"]) self.assertFalse(queue_d2.called) @@ -206,9 +205,9 @@ class BatchingQueueTestCase(TestCase): self.assertFalse(self._pending_calls) queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1", key=1)) - self.clock.pump([0]) + self.reactor.pump([0]) queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2", key=2)) - self.clock.pump([0]) + self.reactor.pump([0]) # We queue up another item with key=2 to check that we will keep taking # things off the queue. @@ -240,7 +239,7 @@ class BatchingQueueTestCase(TestCase): self.assertFalse(queue_d3.called) # We should now see a call `_pending_calls` for `foo3` - self.clock.pump([0]) + self.reactor.pump([0]) self.assertEqual(len(self._pending_calls), 1) self.assertEqual(self._pending_calls[0][0], ["foo3"]) self.assertFalse(queue_d3.called) diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py index 246e18fd15..16e096a4b2 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py @@ -23,12 +23,14 @@ from synapse.util.caches.dictionary_cache import DictionaryCache from tests import unittest +from tests.server import get_clock class DictCacheTestCase(unittest.TestCase): def setUp(self) -> None: + _, clock = get_clock() self.cache: DictionaryCache[str, str, str] = DictionaryCache( - name="foobar", server_name="test_server", max_entries=10 + name="foobar", clock=clock, server_name="test_server", max_entries=10 ) def test_simple_cache_hit_full(self) -> None: diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index eda2d586f6..35c0f02e3f 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -34,6 +34,7 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): cache: ExpiringCache[str, str] = ExpiringCache( cache_name="test", server_name="testserver", + hs=self.hs, clock=clock, max_len=1, ) @@ -47,6 +48,7 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): cache: ExpiringCache[str, str] = ExpiringCache( cache_name="test", server_name="testserver", + hs=self.hs, clock=clock, max_len=2, ) @@ -66,6 +68,7 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): cache: ExpiringCache[str, List[int]] = ExpiringCache( cache_name="test", server_name="testserver", + hs=self.hs, clock=clock, max_len=5, iterable=True, @@ -90,6 +93,7 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): cache: ExpiringCache[str, int] = ExpiringCache( cache_name="test", server_name="testserver", + hs=self.hs, clock=clock, expiry_ms=1000, ) diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 966ea31f1a..ca805bb20a 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -66,7 +66,8 @@ class LoggingContextTestCase(unittest.TestCase): """ Test `Clock.sleep` """ - clock = Clock(reactor, server_name="test_server") + # Ignore linter error since we are creating a `Clock` for testing purposes. + clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] # Sanity check that we start in the sentinel context self._check_test_key("sentinel") @@ -90,7 +91,7 @@ class LoggingContextTestCase(unittest.TestCase): # so that the test can complete and we see the underlying error. callback_finished = True - reactor.callLater(0, lambda: defer.ensureDeferred(competing_callback())) + reactor.callLater(0, lambda: defer.ensureDeferred(competing_callback())) # type: ignore[call-later-not-tracked] with LoggingContext(name="foo", server_name="test_server"): await clock.sleep(0) @@ -111,7 +112,8 @@ class LoggingContextTestCase(unittest.TestCase): """ Test `Clock.looping_call` """ - clock = Clock(reactor, server_name="test_server") + # Ignore linter error since we are creating a `Clock` for testing purposes. + clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] # Sanity check that we start in the sentinel context self._check_test_key("sentinel") @@ -161,7 +163,8 @@ class LoggingContextTestCase(unittest.TestCase): """ Test `Clock.looping_call_now` """ - clock = Clock(reactor, server_name="test_server") + # Ignore linter error since we are creating a `Clock` for testing purposes. + clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] # Sanity check that we start in the sentinel context self._check_test_key("sentinel") @@ -209,7 +212,8 @@ class LoggingContextTestCase(unittest.TestCase): """ Test `Clock.call_later` """ - clock = Clock(reactor, server_name="test_server") + # Ignore linter error since we are creating a `Clock` for testing purposes. + clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] # Sanity check that we start in the sentinel context self._check_test_key("sentinel") @@ -261,7 +265,8 @@ class LoggingContextTestCase(unittest.TestCase): `d.callback(None)` without anything else. See the *Deferred callbacks* section of docs/log_contexts.md for more details. """ - clock = Clock(reactor, server_name="test_server") + # Ignore linter error since we are creating a `Clock` for testing purposes. + clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] # Sanity check that we start in the sentinel context self._check_test_key("sentinel") @@ -318,7 +323,8 @@ class LoggingContextTestCase(unittest.TestCase): `d.callback(None)` without anything else. See the *Deferred callbacks* section of docs/log_contexts.md for more details. """ - clock = Clock(reactor, server_name="test_server") + # Ignore linter error since we are creating a `Clock` for testing purposes. + clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] # Sanity check that we start in the sentinel context self._check_test_key("sentinel") @@ -379,7 +385,8 @@ class LoggingContextTestCase(unittest.TestCase): `d.callback(None)` without anything else. See the *Deferred callbacks* section of docs/log_contexts.md for more details. """ - clock = Clock(reactor, server_name="test_server") + # Ignore linter error since we are creating a `Clock` for testing purposes. + clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] # Sanity check that we start in the sentinel context self._check_test_key("sentinel") @@ -450,7 +457,8 @@ class LoggingContextTestCase(unittest.TestCase): self._check_test_key("sentinel") async def _test_run_in_background(self, function: Callable[[], object]) -> None: - clock = Clock(reactor, server_name="test_server") + # Ignore linter error since we are creating a `Clock` for testing purposes. + clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] # Sanity check that we start in the sentinel context self._check_test_key("sentinel") @@ -492,7 +500,8 @@ class LoggingContextTestCase(unittest.TestCase): @logcontext_clean async def test_run_in_background_with_blocking_fn(self) -> None: async def blocking_function() -> None: - await Clock(reactor, server_name="test_server").sleep(0) + # Ignore linter error since we are creating a `Clock` for testing purposes. + await Clock(reactor, server_name="test_server").sleep(0) # type: ignore[multiple-internal-clocks] await self._test_run_in_background(blocking_function) @@ -525,7 +534,8 @@ class LoggingContextTestCase(unittest.TestCase): async def testfunc() -> None: self._check_test_key("foo") - d = defer.ensureDeferred(Clock(reactor, server_name="test_server").sleep(0)) + # Ignore linter error since we are creating a `Clock` for testing purposes. + d = defer.ensureDeferred(Clock(reactor, server_name="test_server").sleep(0)) # type: ignore[multiple-internal-clocks] self.assertIs(current_context(), SENTINEL_CONTEXT) await d self._check_test_key("foo") @@ -554,7 +564,8 @@ class LoggingContextTestCase(unittest.TestCase): This will stress the logic around incomplete deferreds in `run_coroutine_in_background`. """ - clock = Clock(reactor, server_name="test_server") + # Ignore linter error since we are creating a `Clock` for testing purposes. + clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] # Sanity check that we start in the sentinel context self._check_test_key("sentinel") @@ -645,7 +656,7 @@ class LoggingContextTestCase(unittest.TestCase): # the synapse rules. def blocking_function() -> defer.Deferred: d: defer.Deferred = defer.Deferred() - reactor.callLater(0, d.callback, None) + reactor.callLater(0, d.callback, None) # type: ignore[call-later-not-tracked] return d sentinel_context = current_context() @@ -692,7 +703,7 @@ def _chained_deferred_function() -> defer.Deferred: def cb(res: object) -> defer.Deferred: d2: defer.Deferred = defer.Deferred() - reactor.callLater(0, d2.callback, res) + reactor.callLater(0, d2.callback, res) # type: ignore[call-later-not-tracked] return d2 d.addCallback(cb) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 4d37ad0975..56e9996b00 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -29,18 +29,28 @@ from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entrie from synapse.util.caches.treecache import TreeCache from tests import unittest +from tests.server import get_clock from tests.unittest import override_config class LruCacheTestCase(unittest.HomeserverTestCase): + def setUp(self) -> None: + super().setUp() + + _, self.clock = get_clock() + def test_get_set(self) -> None: - cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server") + cache: LruCache[str, str] = LruCache( + max_size=1, clock=self.clock, server_name="test_server" + ) cache["key"] = "value" self.assertEqual(cache.get("key"), "value") self.assertEqual(cache["key"], "value") def test_eviction(self) -> None: - cache: LruCache[int, int] = LruCache(max_size=2, server_name="test_server") + cache: LruCache[int, int] = LruCache( + max_size=2, clock=self.clock, server_name="test_server" + ) cache[1] = 1 cache[2] = 2 @@ -54,7 +64,9 @@ class LruCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(cache.get(3), 3) def test_setdefault(self) -> None: - cache: LruCache[str, int] = LruCache(max_size=1, server_name="test_server") + cache: LruCache[str, int] = LruCache( + max_size=1, clock=self.clock, server_name="test_server" + ) self.assertEqual(cache.setdefault("key", 1), 1) self.assertEqual(cache.get("key"), 1) self.assertEqual(cache.setdefault("key", 2), 1) @@ -63,7 +75,9 @@ class LruCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(cache.get("key"), 2) def test_pop(self) -> None: - cache: LruCache[str, int] = LruCache(max_size=1, server_name="test_server") + cache: LruCache[str, int] = LruCache( + max_size=1, clock=self.clock, server_name="test_server" + ) cache["key"] = 1 self.assertEqual(cache.pop("key"), 1) self.assertEqual(cache.pop("key"), None) @@ -71,7 +85,10 @@ class LruCacheTestCase(unittest.HomeserverTestCase): def test_del_multi(self) -> None: # The type here isn't quite correct as they don't handle TreeCache well. cache: LruCache[Tuple[str, str], str] = LruCache( - max_size=4, cache_type=TreeCache, server_name="test_server" + max_size=4, + clock=self.clock, + cache_type=TreeCache, + server_name="test_server", ) cache[("animal", "cat")] = "mew" cache[("animal", "dog")] = "woof" @@ -91,7 +108,9 @@ class LruCacheTestCase(unittest.HomeserverTestCase): # Man from del_multi say "Yes". def test_clear(self) -> None: - cache: LruCache[str, int] = LruCache(max_size=1, server_name="test_server") + cache: LruCache[str, int] = LruCache( + max_size=1, clock=self.clock, server_name="test_server" + ) cache["key"] = 1 cache.clear() self.assertEqual(len(cache), 0) @@ -99,7 +118,10 @@ class LruCacheTestCase(unittest.HomeserverTestCase): @override_config({"caches": {"per_cache_factors": {"mycache": 10}}}) def test_special_size(self) -> None: cache: LruCache = LruCache( - max_size=10, server_name="test_server", cache_name="mycache" + max_size=10, + clock=self.clock, + server_name="test_server", + cache_name="mycache", ) self.assertEqual(cache.max_size, 100) @@ -107,7 +129,9 @@ class LruCacheTestCase(unittest.HomeserverTestCase): class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): def test_get(self) -> None: m = Mock() - cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server") + cache: LruCache[str, str] = LruCache( + max_size=1, clock=self.clock, server_name="test_server" + ) cache.set("key", "value") self.assertFalse(m.called) @@ -126,7 +150,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): def test_multi_get(self) -> None: m = Mock() - cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server") + cache: LruCache[str, str] = LruCache( + max_size=1, clock=self.clock, server_name="test_server" + ) cache.set("key", "value") self.assertFalse(m.called) @@ -145,7 +171,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): def test_set(self) -> None: m = Mock() - cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server") + cache: LruCache[str, str] = LruCache( + max_size=1, clock=self.clock, server_name="test_server" + ) cache.set("key", "value", callbacks=[m]) self.assertFalse(m.called) @@ -161,7 +189,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): def test_pop(self) -> None: m = Mock() - cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server") + cache: LruCache[str, str] = LruCache( + max_size=1, clock=self.clock, server_name="test_server" + ) cache.set("key", "value", callbacks=[m]) self.assertFalse(m.called) @@ -182,7 +212,10 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): m4 = Mock() # The type here isn't quite correct as they don't handle TreeCache well. cache: LruCache[Tuple[str, str], str] = LruCache( - max_size=4, cache_type=TreeCache, server_name="test_server" + max_size=4, + clock=self.clock, + cache_type=TreeCache, + server_name="test_server", ) cache.set(("a", "1"), "value", callbacks=[m1]) @@ -205,7 +238,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): def test_clear(self) -> None: m1 = Mock() m2 = Mock() - cache: LruCache[str, str] = LruCache(max_size=5, server_name="test_server") + cache: LruCache[str, str] = LruCache( + max_size=5, clock=self.clock, server_name="test_server" + ) cache.set("key1", "value", callbacks=[m1]) cache.set("key2", "value", callbacks=[m2]) @@ -222,7 +257,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): m1 = Mock(name="m1") m2 = Mock(name="m2") m3 = Mock(name="m3") - cache: LruCache[str, str] = LruCache(max_size=2, server_name="test_server") + cache: LruCache[str, str] = LruCache( + max_size=2, clock=self.clock, server_name="test_server" + ) cache.set("key1", "value", callbacks=[m1]) cache.set("key2", "value", callbacks=[m2]) @@ -259,7 +296,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): class LruCacheSizedTestCase(unittest.HomeserverTestCase): def test_evict(self) -> None: cache: LruCache[str, List[int]] = LruCache( - max_size=5, size_callback=len, server_name="test_server" + max_size=5, clock=self.clock, size_callback=len, server_name="test_server" ) cache["key1"] = [0] cache["key2"] = [1, 2] @@ -284,7 +321,10 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase): def test_zero_size_drop_from_cache(self) -> None: """Test that `drop_from_cache` works correctly with 0-sized entries.""" cache: LruCache[str, List[int]] = LruCache( - max_size=5, size_callback=lambda x: 0, server_name="test_server" + max_size=5, + clock=self.clock, + size_callback=lambda x: 0, + server_name="test_server", ) cache["key1"] = [] @@ -402,7 +442,10 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase): class ExtraIndexLruCacheTestCase(unittest.HomeserverTestCase): def test_invalidate_simple(self) -> None: cache: LruCache[str, int] = LruCache( - max_size=10, server_name="test_server", extra_index_cb=lambda k, v: str(v) + max_size=10, + clock=self.hs.get_clock(), + server_name="test_server", + extra_index_cb=lambda k, v: str(v), ) cache["key1"] = 1 cache["key2"] = 2 @@ -417,7 +460,10 @@ class ExtraIndexLruCacheTestCase(unittest.HomeserverTestCase): def test_invalidate_multi(self) -> None: cache: LruCache[str, int] = LruCache( - max_size=10, server_name="test_server", extra_index_cb=lambda k, v: str(v) + max_size=10, + clock=self.hs.get_clock(), + server_name="test_server", + extra_index_cb=lambda k, v: str(v), ) cache["key1"] = 1 cache["key2"] = 1 diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py index 82baff5883..593be93ea3 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py @@ -35,6 +35,7 @@ class RetryLimiterTestCase(HomeserverTestCase): get_retry_limiter( destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, ) @@ -57,6 +58,7 @@ class RetryLimiterTestCase(HomeserverTestCase): get_retry_limiter( destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, ) @@ -89,6 +91,7 @@ class RetryLimiterTestCase(HomeserverTestCase): get_retry_limiter( destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, ), @@ -104,6 +107,7 @@ class RetryLimiterTestCase(HomeserverTestCase): get_retry_limiter( destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, ) @@ -139,6 +143,7 @@ class RetryLimiterTestCase(HomeserverTestCase): get_retry_limiter( destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, ) @@ -165,6 +170,7 @@ class RetryLimiterTestCase(HomeserverTestCase): get_retry_limiter( destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, notifier=notifier, @@ -238,6 +244,7 @@ class RetryLimiterTestCase(HomeserverTestCase): get_retry_limiter( destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, ) @@ -261,6 +268,7 @@ class RetryLimiterTestCase(HomeserverTestCase): get_retry_limiter( destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, ), @@ -273,6 +281,7 @@ class RetryLimiterTestCase(HomeserverTestCase): get_retry_limiter( destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, ) @@ -297,6 +306,7 @@ class RetryLimiterTestCase(HomeserverTestCase): get_retry_limiter( destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, ),