Add in-memory database connection handling in DatabaseProvider

This commit is contained in:
Sudo-Ivan
2026-01-14 17:45:37 -06:00
parent 6a4ed6a048
commit 7e57cc2b24
+38 -1
View File
@@ -12,6 +12,7 @@ class DatabaseProvider:
self.db_path = db_path
self._local = threading.local()
self._all_locals.add(self._local)
self._memory_connection = None
@classmethod
def get_instance(cls, db_path=None):
@@ -29,6 +30,21 @@ class DatabaseProvider:
@property
def connection(self):
# In-memory databases are private to the connection.
# If we use threading.local(), each thread gets a DIFFERENT in-memory database.
# For :memory:, we must share the connection across threads.
if self.db_path == ":memory:":
if self._memory_connection is None:
with self._lock:
if self._memory_connection is None:
self._memory_connection = sqlite3.connect(
self.db_path,
check_same_thread=False,
isolation_level=None,
)
self._memory_connection.row_factory = sqlite3.Row
return self._memory_connection
if not hasattr(self._local, "connection"):
# isolation_level=None enables autocommit mode, letting us manage transactions manually
self._local.connection = sqlite3.connect(
@@ -38,7 +54,12 @@ class DatabaseProvider:
)
self._local.connection.row_factory = sqlite3.Row
# Enable WAL mode for better concurrency
self._local.connection.execute("PRAGMA journal_mode=WAL")
if self.db_path != ":memory:":
try:
self._local.connection.execute("PRAGMA journal_mode=WAL")
except sqlite3.OperationalError:
# Some environments might not support WAL
pass
return self._local.connection
def execute(self, query, params=None, commit=None):
@@ -115,6 +136,14 @@ class DatabaseProvider:
return [dict(row) for row in rows]
def close(self):
if self.db_path == ":memory:" and self._memory_connection:
try:
self._memory_connection.commit()
self._memory_connection.close()
except Exception: # noqa: S110
pass
self._memory_connection = None
if hasattr(self._local, "connection"):
try:
self.commit() # Ensure everything is saved
@@ -125,6 +154,14 @@ class DatabaseProvider:
def close_all(self):
with self._lock:
if self._memory_connection:
try:
self._memory_connection.commit()
self._memory_connection.close()
except Exception: # noqa: S110
pass
self._memory_connection = None
for loc in self._all_locals:
if hasattr(loc, "connection"):
try: