diff --git a/meshchatx/src/backend/database/provider.py b/meshchatx/src/backend/database/provider.py index e60e294..afd04af 100644 --- a/meshchatx/src/backend/database/provider.py +++ b/meshchatx/src/backend/database/provider.py @@ -12,6 +12,7 @@ class DatabaseProvider: self.db_path = db_path self._local = threading.local() self._all_locals.add(self._local) + self._memory_connection = None @classmethod def get_instance(cls, db_path=None): @@ -29,6 +30,21 @@ class DatabaseProvider: @property def connection(self): + # In-memory databases are private to the connection. + # If we use threading.local(), each thread gets a DIFFERENT in-memory database. + # For :memory:, we must share the connection across threads. + if self.db_path == ":memory:": + if self._memory_connection is None: + with self._lock: + if self._memory_connection is None: + self._memory_connection = sqlite3.connect( + self.db_path, + check_same_thread=False, + isolation_level=None, + ) + self._memory_connection.row_factory = sqlite3.Row + return self._memory_connection + if not hasattr(self._local, "connection"): # isolation_level=None enables autocommit mode, letting us manage transactions manually self._local.connection = sqlite3.connect( @@ -38,7 +54,12 @@ class DatabaseProvider: ) self._local.connection.row_factory = sqlite3.Row # Enable WAL mode for better concurrency - self._local.connection.execute("PRAGMA journal_mode=WAL") + if self.db_path != ":memory:": + try: + self._local.connection.execute("PRAGMA journal_mode=WAL") + except sqlite3.OperationalError: + # Some environments might not support WAL + pass return self._local.connection def execute(self, query, params=None, commit=None): @@ -115,6 +136,14 @@ class DatabaseProvider: return [dict(row) for row in rows] def close(self): + if self.db_path == ":memory:" and self._memory_connection: + try: + self._memory_connection.commit() + self._memory_connection.close() + except Exception: # noqa: S110 + pass + self._memory_connection = None + if hasattr(self._local, "connection"): try: self.commit() # Ensure everything is saved @@ -125,6 +154,14 @@ class DatabaseProvider: def close_all(self): with self._lock: + if self._memory_connection: + try: + self._memory_connection.commit() + self._memory_connection.close() + except Exception: # noqa: S110 + pass + self._memory_connection = None + for loc in self._all_locals: if hasattr(loc, "connection"): try: