This commit is contained in:
Matthew Hodgson
2026-03-24 13:50:14 -04:00
parent 78beff2b31
commit 7b7dda3879
3 changed files with 16 additions and 0 deletions

View File

@@ -607,6 +607,8 @@ class DatabasePool:
# Create a new in-memory DB and copy schema+data from the template
fresh_conn = sqlite3.connect(":memory:", check_same_thread=False)
source_conn.backup(fresh_conn)
# Re-register custom functions that don't survive backup()
engine.register_custom_functions(fresh_conn)
initial_conn = fresh_conn
self._db_pool = NativeConnectionPool(

View File

@@ -84,6 +84,12 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM
@abc.abstractmethod
def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None: ...
def register_custom_functions(self, raw_conn: "Any") -> None:
"""Register custom database functions on a raw connection.
Override in subclasses that need custom functions (e.g. SQLite rank()).
"""
@abc.abstractmethod
def is_deadlock(self, error: Exception) -> bool: ...

View File

@@ -86,6 +86,14 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]):
def convert_param_style(self, sql: str) -> str:
return sql
def register_custom_functions(self, raw_conn: sqlite3.Connection) -> None:
"""Register custom SQLite functions on a raw connection.
This must be called on any connection created outside the normal
on_new_connection path (e.g. connections created via backup()).
"""
raw_conn.create_function("rank", 1, _rank)
def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
# We need to import here to avoid an import loop.
from synapse.storage.prepare_database import prepare_database