diff --git a/meshchatx/src/backend/database/legacy_migrator.py b/meshchatx/src/backend/database/legacy_migrator.py index dbe61f6..5c06675 100644 --- a/meshchatx/src/backend/database/legacy_migrator.py +++ b/meshchatx/src/backend/database/legacy_migrator.py @@ -116,9 +116,12 @@ class LegacyMigrator: ) ] - # Find common columns + # Find common columns, but exclude 'id' to avoid collisions during migration + # as new databases will have their own autoincrement IDs. common_columns = [ - col for col in legacy_columns if col in current_columns + col + for col in legacy_columns + if col in current_columns and col.lower() != "id" ] if common_columns: diff --git a/meshchatx/src/backend/database/provider.py b/meshchatx/src/backend/database/provider.py index afd04af..1ac8f69 100644 --- a/meshchatx/src/backend/database/provider.py +++ b/meshchatx/src/backend/database/provider.py @@ -49,6 +49,7 @@ class DatabaseProvider: # isolation_level=None enables autocommit mode, letting us manage transactions manually self._local.connection = sqlite3.connect( self.db_path, + timeout=30.0, check_same_thread=False, isolation_level=None, ) diff --git a/meshchatx/src/backend/identity_manager.py b/meshchatx/src/backend/identity_manager.py index d6fe088..2dc4add 100644 --- a/meshchatx/src/backend/identity_manager.py +++ b/meshchatx/src/backend/identity_manager.py @@ -186,9 +186,8 @@ class IdentityManager: # Merge with existing metadata if it exists existing_metadata = {} if os.path.exists(metadata_path): - with contextlib.suppress(Exception): - with open(metadata_path) as f: - existing_metadata = json.load(f) + with contextlib.suppress(Exception), open(metadata_path) as f: + existing_metadata = json.load(f) existing_metadata.update(metadata) diff --git a/tests/backend/conftest.py b/tests/backend/conftest.py index ed0cb7a..9fe7e1d 100644 --- a/tests/backend/conftest.py +++ b/tests/backend/conftest.py @@ -19,7 +19,6 @@ def global_mocks(): return_value=None, ), patch("meshchatx.meshchat.generate_ssl_certificate", return_value=None), - patch("threading.Thread"), patch("asyncio.sleep", side_effect=lambda *args, **kwargs: asyncio.sleep(0)), ): # Mock run_async to properly close coroutines diff --git a/tests/backend/test_concurrency_stress.py b/tests/backend/test_concurrency_stress.py new file mode 100644 index 0000000..4026325 --- /dev/null +++ b/tests/backend/test_concurrency_stress.py @@ -0,0 +1,159 @@ +import os +import secrets +import shutil +import tempfile +import threading +import unittest +import time +from meshchatx.src.backend.database import Database +from meshchatx.src.backend.database.provider import DatabaseProvider +from meshchatx.src.backend.identity_manager import IdentityManager + + +class TestConcurrencyStress(unittest.TestCase): + def setUp(self): + # Reset DatabaseProvider singleton for clean state + DatabaseProvider._instance = None + + self.test_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.test_dir, "stress.db") + self.db = Database(self.db_path) + self.db.initialize() + self.stop_threads = False + self.errors = [] + + def tearDown(self): + self.stop_threads = True + self.db.close_all() + # Reset again + DatabaseProvider._instance = None + if os.path.exists(self.test_dir): + shutil.rmtree(self.test_dir) + + def db_writer_worker(self, worker_id): + """Spams the message table with inserts and updates.""" + try: + from meshchatx.src.backend.database.messages import MessageDAO + + provider = DatabaseProvider.get_instance(self.db_path) + dao = MessageDAO(provider) + peer_hash = secrets.token_hex(16) + count = 0 + while not self.stop_threads and count < 50: + msg = { + "hash": secrets.token_hex(16), + "source_hash": peer_hash, + "destination_hash": "my_hash", + "peer_hash": peer_hash, + "state": "delivered", + "progress": 1.0, + "is_incoming": 1, + "method": "direct", + "delivery_attempts": 1, + "title": f"Stress Msg {worker_id}-{count}", + "content": "A" * 128, + "fields": "{}", + "timestamp": time.time(), + "rssi": -50, + "snr": 5.0, + "quality": 3, + "is_spam": 0, + } + with provider: + dao.upsert_lxmf_message(msg) + count += 1 + time.sleep(0.001) + except Exception as e: + self.errors.append(f"Writer Worker {worker_id} ERROR: {e}") + + def db_reader_worker(self, worker_id): + """Spams the message table with reads and searches.""" + try: + from meshchatx.src.backend.database.messages import MessageDAO + from meshchatx.src.backend.database.announces import AnnounceDAO + + provider = DatabaseProvider.get_instance(self.db_path) + msg_dao = MessageDAO(provider) + ann_dao = AnnounceDAO(provider) + + count = 0 + while not self.stop_threads and count < 50: + # Perform various reads + msg_dao.get_conversations() + ann_dao.get_filtered_announces(limit=10) + count += 1 + time.sleep(0.001) + except Exception as e: + self.errors.append(f"Reader Worker {worker_id} ERROR: {e}") + + def test_database_concurrency(self): + """Launches multiple reader and writer threads to check for lock contention.""" + writers = [ + threading.Thread(target=self.db_writer_worker, args=(i,)) for i in range(5) + ] + readers = [ + threading.Thread(target=self.db_reader_worker, args=(i,)) for i in range(5) + ] + + for t in writers + readers: + t.start() + + for t in writers + readers: + t.join() + + # Assert no errors occurred in threads + if self.errors: + self.fail(f"Errors occurred in threads: \n" + "\n".join(self.errors)) + + # Check if we ended up with the expected number of messages + total = self.db.provider.fetchone( + "SELECT COUNT(*) as count FROM lxmf_messages" + )["count"] + self.assertEqual( + total, 5 * 50, "Total messages inserted doesn't match expected count" + ) + print(f"Stress test completed. Total messages inserted: {total}") + + def test_identity_and_db_collision(self): + """Tests potential collisions between IdentityManager and Database access.""" + manager = IdentityManager(self.test_dir) + + def identity_worker(): + try: + for i in range(20): + if self.stop_threads: + break + manager.create_identity(f"Stress ID {i}") + manager.list_identities() + time.sleep(0.01) + except Exception as e: + self.errors.append(f"Identity Worker ERROR: {e}") + + id_thread = threading.Thread(target=identity_worker) + db_thread = threading.Thread( + target=self.db_writer_worker, args=("id_collision",) + ) + + id_thread.start() + db_thread.start() + + id_thread.join() + db_thread.join() + + # Assert no errors occurred + if self.errors: + self.fail(f"Errors occurred in threads: \n" + "\n".join(self.errors)) + + identities = manager.list_identities() + self.assertEqual(len(identities), 20, "Should have created 20 identities") + + total_messages = self.db.provider.fetchone( + "SELECT COUNT(*) as count FROM lxmf_messages" + )["count"] + self.assertEqual( + total_messages, 50, "Should have inserted 50 messages during collision test" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/backend/test_database_evolution.py b/tests/backend/test_database_evolution.py new file mode 100644 index 0000000..4e72ad1 --- /dev/null +++ b/tests/backend/test_database_evolution.py @@ -0,0 +1,158 @@ +import os +import shutil +import sqlite3 +import tempfile +import unittest +from meshchatx.src.backend.database import Database +from meshchatx.src.backend.database.provider import DatabaseProvider +from meshchatx.src.backend.database.legacy_migrator import LegacyMigrator + + +class TestDatabaseMigration(unittest.TestCase): + def setUp(self): + DatabaseProvider._instance = None + self.test_dir = tempfile.mkdtemp() + # Legacy migrator expects a specific structure: reticulum_config_dir/identities/identity_hash_hex/database.db + self.identity_hash = "deadbeef" + self.legacy_config_dir = os.path.join(self.test_dir, "legacy_config") + self.legacy_db_subdir = os.path.join( + self.legacy_config_dir, "identities", self.identity_hash + ) + os.makedirs(self.legacy_db_subdir, exist_ok=True) + self.legacy_db_path = os.path.join(self.legacy_db_subdir, "database.db") + + # Create legacy database with 1.x/2.x schema + self.create_legacy_db(self.legacy_db_path) + + # Current database + self.current_db_path = os.path.join(self.test_dir, "current.db") + self.db = Database(self.current_db_path) + self.db.initialize() + + def tearDown(self): + self.db.close_all() + shutil.rmtree(self.test_dir) + + def create_legacy_db(self, path): + conn = sqlite3.connect(path) + cursor = conn.cursor() + + # Based on liamcottle/reticulum-meshchat/database.py + cursor.execute(""" + CREATE TABLE config ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + key TEXT UNIQUE, + value TEXT, + created_at DATETIME, + updated_at DATETIME + ) + """) + + cursor.execute(""" + CREATE TABLE announces ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + destination_hash TEXT UNIQUE, + aspect TEXT, + identity_hash TEXT, + identity_public_key TEXT, + app_data TEXT, + rssi INTEGER, + snr REAL, + quality REAL, + created_at DATETIME, + updated_at DATETIME + ) + """) + + cursor.execute(""" + CREATE TABLE lxmf_messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + hash TEXT UNIQUE, + source_hash TEXT, + destination_hash TEXT, + state TEXT, + progress REAL, + is_incoming INTEGER, + method TEXT, + delivery_attempts INTEGER, + next_delivery_attempt_at REAL, + title TEXT, + content TEXT, + fields TEXT, + timestamp REAL, + rssi INTEGER, + snr REAL, + quality REAL, + created_at DATETIME, + updated_at DATETIME + ) + """) + + # Insert some legacy data + cursor.execute( + "INSERT INTO config (key, value) VALUES (?, ?)", + ("legacy_key", "legacy_value"), + ) + cursor.execute( + "INSERT INTO announces (destination_hash, aspect, identity_hash) VALUES (?, ?, ?)", + ("dest1", "lxmf.delivery", "id1"), + ) + cursor.execute( + "INSERT INTO lxmf_messages (hash, source_hash, destination_hash, title, content, fields, is_incoming, state, progress, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + "msg1", + "src1", + "dest1", + "Old Title", + "Old Content", + "{}", + 1, + "delivered", + 1.0, + 123456789.0, + ), + ) + + conn.commit() + conn.close() + + def test_migration_evolution(self): + migrator = LegacyMigrator( + self.db.provider, self.legacy_config_dir, self.identity_hash + ) + + # Check if should migrate + self.assertTrue( + migrator.should_migrate(), "Should detect legacy database for migration" + ) + + # Perform migration + success = migrator.migrate() + self.assertTrue(success, "Migration should complete successfully") + + # Verify data in current database + config_rows = self.db.provider.fetchall("SELECT * FROM config") + print(f"Config rows: {config_rows}") + + config_val = self.db.provider.fetchone( + "SELECT value FROM config WHERE key = ?", ("legacy_key",) + ) + self.assertIsNotNone(config_val, "legacy_key should have been migrated") + self.assertEqual(config_val["value"], "legacy_value") + + ann_count = self.db.provider.fetchone( + "SELECT COUNT(*) as count FROM announces" + )["count"] + self.assertEqual(ann_count, 1) + + msg = self.db.provider.fetchone( + "SELECT * FROM lxmf_messages WHERE hash = ?", ("msg1",) + ) + self.assertIsNotNone(msg) + self.assertEqual(msg["title"], "Old Title") + self.assertEqual(msg["content"], "Old Content") + self.assertEqual(msg["source_hash"], "src1") + + +if __name__ == "__main__": + unittest.main()