mirror of
https://git.quad4.io/RNS-Things/MeshChatX.git
synced 2026-04-25 08:52:15 +00:00
Refactor identity manager metadata loading and improve legacy migrator column handling
- Simplified metadata loading in IdentityManager by combining context managers.
- Updated LegacyMigrator to exclude 'id' from common columns during migration to prevent collisions.
- Increased SQLite connection timeout in DatabaseProvider for improved reliability.
- Removed unnecessary thread patching in test configuration.
- Added concurrency stress tests for database operations and identity management.
- Introduced database migration tests to validate legacy data handling and migration success.
This commit is contained in:
@@ -116,9 +116,12 @@ class LegacyMigrator:
|
||||
)
|
||||
]
|
||||
|
||||
# Find common columns
|
||||
# Find common columns, but exclude 'id' to avoid collisions during migration
|
||||
# as new databases will have their own autoincrement IDs.
|
||||
common_columns = [
|
||||
col for col in legacy_columns if col in current_columns
|
||||
col
|
||||
for col in legacy_columns
|
||||
if col in current_columns and col.lower() != "id"
|
||||
]
|
||||
|
||||
if common_columns:
|
||||
|
||||
@@ -49,6 +49,7 @@ class DatabaseProvider:
|
||||
# isolation_level=None enables autocommit mode, letting us manage transactions manually
|
||||
self._local.connection = sqlite3.connect(
|
||||
self.db_path,
|
||||
timeout=30.0,
|
||||
check_same_thread=False,
|
||||
isolation_level=None,
|
||||
)
|
||||
|
||||
@@ -186,9 +186,8 @@ class IdentityManager:
|
||||
# Merge with existing metadata if it exists
|
||||
existing_metadata = {}
|
||||
if os.path.exists(metadata_path):
|
||||
with contextlib.suppress(Exception):
|
||||
with open(metadata_path) as f:
|
||||
existing_metadata = json.load(f)
|
||||
with contextlib.suppress(Exception), open(metadata_path) as f:
|
||||
existing_metadata = json.load(f)
|
||||
|
||||
existing_metadata.update(metadata)
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ def global_mocks():
|
||||
return_value=None,
|
||||
),
|
||||
patch("meshchatx.meshchat.generate_ssl_certificate", return_value=None),
|
||||
patch("threading.Thread"),
|
||||
patch("asyncio.sleep", side_effect=lambda *args, **kwargs: asyncio.sleep(0)),
|
||||
):
|
||||
# Mock run_async to properly close coroutines
|
||||
|
||||
159
tests/backend/test_concurrency_stress.py
Normal file
159
tests/backend/test_concurrency_stress.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import os
|
||||
import secrets
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import unittest
|
||||
import time
|
||||
from meshchatx.src.backend.database import Database
|
||||
from meshchatx.src.backend.database.provider import DatabaseProvider
|
||||
from meshchatx.src.backend.identity_manager import IdentityManager
|
||||
|
||||
|
||||
class TestConcurrencyStress(unittest.TestCase):
    """Stress tests for concurrent access to the SQLite-backed Database.

    Launches several reader and writer threads against one shared database
    file to surface lock contention ("database is locked"), then verifies
    that no worker raised and that the final row counts match what was
    written. Also exercises IdentityManager concurrently with DB writes.
    """

    def setUp(self):
        """Create a fresh temp database and reset the provider singleton."""
        # Reset DatabaseProvider singleton so each test starts with clean state.
        DatabaseProvider._instance = None

        self.test_dir = tempfile.mkdtemp()
        self.db_path = os.path.join(self.test_dir, "stress.db")
        self.db = Database(self.db_path)
        self.db.initialize()
        self.stop_threads = False
        # Worker threads append error strings here; the main thread asserts on it.
        self.errors = []

    def tearDown(self):
        """Stop workers, close connections, and remove temporary files."""
        self.stop_threads = True
        self.db.close_all()
        # Reset again so the singleton does not leak into other test modules.
        DatabaseProvider._instance = None
        if os.path.exists(self.test_dir):
            shutil.rmtree(self.test_dir)

    def db_writer_worker(self, worker_id):
        """Spams the message table with inserts and updates."""
        try:
            from meshchatx.src.backend.database.messages import MessageDAO

            provider = DatabaseProvider.get_instance(self.db_path)
            dao = MessageDAO(provider)
            peer_hash = secrets.token_hex(16)
            count = 0
            while not self.stop_threads and count < 50:
                msg = {
                    "hash": secrets.token_hex(16),
                    "source_hash": peer_hash,
                    "destination_hash": "my_hash",
                    "peer_hash": peer_hash,
                    "state": "delivered",
                    "progress": 1.0,
                    "is_incoming": 1,
                    "method": "direct",
                    "delivery_attempts": 1,
                    "title": f"Stress Msg {worker_id}-{count}",
                    "content": "A" * 128,
                    "fields": "{}",
                    "timestamp": time.time(),
                    "rssi": -50,
                    "snr": 5.0,
                    "quality": 3,
                    "is_spam": 0,
                }
                with provider:
                    dao.upsert_lxmf_message(msg)
                count += 1
                time.sleep(0.001)
        except Exception as e:
            # Threads must not raise into the test runner; record and assert later.
            self.errors.append(f"Writer Worker {worker_id} ERROR: {e}")

    def db_reader_worker(self, worker_id):
        """Spams the message table with reads and searches."""
        try:
            from meshchatx.src.backend.database.messages import MessageDAO
            from meshchatx.src.backend.database.announces import AnnounceDAO

            provider = DatabaseProvider.get_instance(self.db_path)
            msg_dao = MessageDAO(provider)
            ann_dao = AnnounceDAO(provider)

            count = 0
            while not self.stop_threads and count < 50:
                # Perform various reads
                msg_dao.get_conversations()
                ann_dao.get_filtered_announces(limit=10)
                count += 1
                time.sleep(0.001)
        except Exception as e:
            self.errors.append(f"Reader Worker {worker_id} ERROR: {e}")

    def test_database_concurrency(self):
        """Launches multiple reader and writer threads to check for lock contention."""
        writers = [
            threading.Thread(target=self.db_writer_worker, args=(i,)) for i in range(5)
        ]
        readers = [
            threading.Thread(target=self.db_reader_worker, args=(i,)) for i in range(5)
        ]

        for t in writers + readers:
            t.start()

        for t in writers + readers:
            t.join()

        # Assert no errors occurred in threads.
        # FIX: was an f-string with no placeholders (useless f-string, ruff F541)
        # and a stray space before the newline; plain literal now.
        if self.errors:
            self.fail("Errors occurred in threads:\n" + "\n".join(self.errors))

        # Check if we ended up with the expected number of messages
        total = self.db.provider.fetchone(
            "SELECT COUNT(*) as count FROM lxmf_messages"
        )["count"]
        self.assertEqual(
            total, 5 * 50, "Total messages inserted doesn't match expected count"
        )
        print(f"Stress test completed. Total messages inserted: {total}")

    def test_identity_and_db_collision(self):
        """Tests potential collisions between IdentityManager and Database access."""
        manager = IdentityManager(self.test_dir)

        def identity_worker():
            # Creates and lists identities while the DB writer runs in parallel.
            try:
                for i in range(20):
                    if self.stop_threads:
                        break
                    manager.create_identity(f"Stress ID {i}")
                    manager.list_identities()
                    time.sleep(0.01)
            except Exception as e:
                self.errors.append(f"Identity Worker ERROR: {e}")

        id_thread = threading.Thread(target=identity_worker)
        db_thread = threading.Thread(
            target=self.db_writer_worker, args=("id_collision",)
        )

        id_thread.start()
        db_thread.start()

        id_thread.join()
        db_thread.join()

        # Assert no errors occurred (same F541 fix as above).
        if self.errors:
            self.fail("Errors occurred in threads:\n" + "\n".join(self.errors))

        identities = manager.list_identities()
        self.assertEqual(len(identities), 20, "Should have created 20 identities")

        total_messages = self.db.provider.fetchone(
            "SELECT COUNT(*) as count FROM lxmf_messages"
        )["count"]
        self.assertEqual(
            total_messages, 50, "Should have inserted 50 messages during collision test"
        )
|
||||
|
||||
|
||||
# Allow running this test module directly with `python test_concurrency_stress.py`
# in addition to discovery via pytest/unittest.
if __name__ == "__main__":
    unittest.main()
|
||||
158
tests/backend/test_database_evolution.py
Normal file
158
tests/backend/test_database_evolution.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import os
|
||||
import shutil
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import unittest
|
||||
from meshchatx.src.backend.database import Database
|
||||
from meshchatx.src.backend.database.provider import DatabaseProvider
|
||||
from meshchatx.src.backend.database.legacy_migrator import LegacyMigrator
|
||||
|
||||
|
||||
class TestDatabaseMigration(unittest.TestCase):
    """Validates LegacyMigrator against a hand-built 1.x/2.x-schema database.

    Builds a legacy database in the directory layout the migrator expects,
    runs the migration into a freshly initialized current database, and
    verifies the config, announce, and message rows arrived intact.
    """

    def setUp(self):
        """Create a legacy-schema source DB and an initialized target DB."""
        DatabaseProvider._instance = None
        self.test_dir = tempfile.mkdtemp()
        # Legacy migrator expects a specific structure:
        # reticulum_config_dir/identities/<identity_hash_hex>/database.db
        self.identity_hash = "deadbeef"
        self.legacy_config_dir = os.path.join(self.test_dir, "legacy_config")
        self.legacy_db_subdir = os.path.join(
            self.legacy_config_dir, "identities", self.identity_hash
        )
        os.makedirs(self.legacy_db_subdir, exist_ok=True)
        self.legacy_db_path = os.path.join(self.legacy_db_subdir, "database.db")

        # Create legacy database with 1.x/2.x schema
        self.create_legacy_db(self.legacy_db_path)

        # Current database
        self.current_db_path = os.path.join(self.test_dir, "current.db")
        self.db = Database(self.current_db_path)
        self.db.initialize()

    def tearDown(self):
        """Close connections and remove temporary files."""
        self.db.close_all()
        # FIX: also reset the provider singleton here (the sibling stress-test
        # class does this); otherwise a stale provider leaks into later tests.
        DatabaseProvider._instance = None
        shutil.rmtree(self.test_dir)

    def create_legacy_db(self, path):
        """Build a legacy-schema SQLite database at *path* with sample rows."""
        conn = sqlite3.connect(path)
        # FIX: try/finally guarantees the connection is closed even if a
        # statement raises; a leaked handle would make rmtree fail on Windows.
        try:
            cursor = conn.cursor()

            # Based on liamcottle/reticulum-meshchat/database.py
            cursor.execute("""
                CREATE TABLE config (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    key TEXT UNIQUE,
                    value TEXT,
                    created_at DATETIME,
                    updated_at DATETIME
                )
            """)

            cursor.execute("""
                CREATE TABLE announces (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    destination_hash TEXT UNIQUE,
                    aspect TEXT,
                    identity_hash TEXT,
                    identity_public_key TEXT,
                    app_data TEXT,
                    rssi INTEGER,
                    snr REAL,
                    quality REAL,
                    created_at DATETIME,
                    updated_at DATETIME
                )
            """)

            cursor.execute("""
                CREATE TABLE lxmf_messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    hash TEXT UNIQUE,
                    source_hash TEXT,
                    destination_hash TEXT,
                    state TEXT,
                    progress REAL,
                    is_incoming INTEGER,
                    method TEXT,
                    delivery_attempts INTEGER,
                    next_delivery_attempt_at REAL,
                    title TEXT,
                    content TEXT,
                    fields TEXT,
                    timestamp REAL,
                    rssi INTEGER,
                    snr REAL,
                    quality REAL,
                    created_at DATETIME,
                    updated_at DATETIME
                )
            """)

            # Insert some legacy data
            cursor.execute(
                "INSERT INTO config (key, value) VALUES (?, ?)",
                ("legacy_key", "legacy_value"),
            )
            cursor.execute(
                "INSERT INTO announces (destination_hash, aspect, identity_hash) VALUES (?, ?, ?)",
                ("dest1", "lxmf.delivery", "id1"),
            )
            cursor.execute(
                "INSERT INTO lxmf_messages (hash, source_hash, destination_hash, title, content, fields, is_incoming, state, progress, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
                (
                    "msg1",
                    "src1",
                    "dest1",
                    "Old Title",
                    "Old Content",
                    "{}",
                    1,
                    "delivered",
                    1.0,
                    123456789.0,
                ),
            )

            conn.commit()
        finally:
            conn.close()

    def test_migration_evolution(self):
        """Runs the migrator end-to-end and checks every migrated table."""
        migrator = LegacyMigrator(
            self.db.provider, self.legacy_config_dir, self.identity_hash
        )

        # Check if should migrate
        self.assertTrue(
            migrator.should_migrate(), "Should detect legacy database for migration"
        )

        # Perform migration
        success = migrator.migrate()
        self.assertTrue(success, "Migration should complete successfully")

        # Verify data in current database
        config_rows = self.db.provider.fetchall("SELECT * FROM config")
        print(f"Config rows: {config_rows}")

        config_val = self.db.provider.fetchone(
            "SELECT value FROM config WHERE key = ?", ("legacy_key",)
        )
        self.assertIsNotNone(config_val, "legacy_key should have been migrated")
        self.assertEqual(config_val["value"], "legacy_value")

        ann_count = self.db.provider.fetchone(
            "SELECT COUNT(*) as count FROM announces"
        )["count"]
        self.assertEqual(ann_count, 1)

        msg = self.db.provider.fetchone(
            "SELECT * FROM lxmf_messages WHERE hash = ?", ("msg1",)
        )
        self.assertIsNotNone(msg)
        self.assertEqual(msg["title"], "Old Title")
        self.assertEqual(msg["content"], "Old Content")
        self.assertEqual(msg["source_hash"], "src1")
||||
|
||||
|
||||
# Allow running this test module directly with `python test_database_evolution.py`
# in addition to discovery via pytest/unittest.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user