diff --git a/meshchatx/src/backend/database/__init__.py b/meshchatx/src/backend/database/__init__.py
index 88e8681..c5c3db9 100644
--- a/meshchatx/src/backend/database/__init__.py
+++ b/meshchatx/src/backend/database/__init__.py
@@ -43,6 +43,7 @@ class Database:
         self.debug_logs = DebugLogsDAO(self.provider)
 
     def initialize(self):
+        self._tune_sqlite_pragmas()
         self.schema.initialize()
 
     def migrate_from_legacy(self, reticulum_config_dir, identity_hash_hex):
@@ -60,9 +61,12 @@ class Database:
 
     def _tune_sqlite_pragmas(self):
        try:
+            self.execute_sql("PRAGMA journal_mode=WAL")  # readers no longer block writers
+            self.execute_sql("PRAGMA synchronous=NORMAL")  # safe with WAL; fsync mainly at checkpoints
             self.execute_sql("PRAGMA wal_autocheckpoint=1000")
             self.execute_sql("PRAGMA temp_store=MEMORY")
-            self.execute_sql("PRAGMA journal_mode=WAL")
+            self.execute_sql("PRAGMA cache_size=-8000")  # 8 MB
+            self.execute_sql("PRAGMA mmap_size=67108864")  # 64 MB
         except Exception as exc:
             print(f"SQLite pragma setup failed: {exc}")
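
Note: the PRAGMA block above assumes `execute_sql` runs on the same connection the DAOs share. As a minimal standalone sketch of the same tuning against a plain `sqlite3` connection (file name illustrative, not MeshChat's API):

    import sqlite3

    conn = sqlite3.connect("database.db")
    # journal_mode=WAL is persisted in the database file; the other pragmas
    # are per-connection and must be reapplied on every connect.
    print(conn.execute("PRAGMA journal_mode=WAL").fetchone()[0])  # -> "wal"
    conn.execute("PRAGMA synchronous=NORMAL")
    conn.execute("PRAGMA cache_size=-8000")    # negative = size in KiB, so ~8 MB
    conn.execute("PRAGMA mmap_size=67108864")  # 64 MiB of memory-mapped I/O

One caveat worth knowing: `PRAGMA journal_mode=WAL` is a no-op for `:memory:` databases (it reports "memory"), so the pragma assertions in the performance tests further down only hold for file-backed databases.
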
diff --git a/meshchatx/src/backend/database/messages.py b/meshchatx/src/backend/database/messages.py
index 2481ba9..4548397 100644
--- a/meshchatx/src/backend/database/messages.py
+++ b/meshchatx/src/backend/database/messages.py
@@ -135,17 +135,18 @@ class MessageDAO:
         if not destination_hashes:
             return
         now = datetime.now(UTC).isoformat()
-        for destination_hash in destination_hashes:
-            self.provider.execute(
-                """
-                INSERT INTO lxmf_conversation_read_state (destination_hash, last_read_at, created_at, updated_at)
-                VALUES (?, ?, ?, ?)
-                ON CONFLICT(destination_hash) DO UPDATE SET
-                    last_read_at = EXCLUDED.last_read_at,
-                    updated_at = EXCLUDED.updated_at
-                """,
-                (destination_hash, now, now, now),
-            )
+        with self.provider:  # one transaction for the whole batch
+            for destination_hash in destination_hashes:
+                self.provider.execute(
+                    """
+                    INSERT INTO lxmf_conversation_read_state (destination_hash, last_read_at, created_at, updated_at)
+                    VALUES (?, ?, ?, ?)
+                    ON CONFLICT(destination_hash) DO UPDATE SET
+                        last_read_at = EXCLUDED.last_read_at,
+                        updated_at = EXCLUDED.updated_at
+                    """,
+                    (destination_hash, now, now, now),
+                )
 
     def is_conversation_unread(self, destination_hash):
         row = self.provider.fetchone(
@@ -310,17 +311,18 @@ class MessageDAO:
     def mark_all_notifications_as_viewed(self, destination_hashes=None):
         now = datetime.now(UTC).isoformat()
         if destination_hashes:
-            for destination_hash in destination_hashes:
-                self.provider.execute(
-                    """
-                    INSERT INTO notification_viewed_state (destination_hash, last_viewed_at, created_at, updated_at)
-                    VALUES (?, ?, ?, ?)
-                    ON CONFLICT(destination_hash) DO UPDATE SET
-                        last_viewed_at = EXCLUDED.last_viewed_at,
-                        updated_at = EXCLUDED.updated_at
-                    """,
-                    (destination_hash, now, now, now),
-                )
+            with self.provider:  # one transaction for the whole batch
+                for destination_hash in destination_hashes:
+                    self.provider.execute(
+                        """
+                        INSERT INTO notification_viewed_state (destination_hash, last_viewed_at, created_at, updated_at)
+                        VALUES (?, ?, ?, ?)
+                        ON CONFLICT(destination_hash) DO UPDATE SET
+                            last_viewed_at = EXCLUDED.last_viewed_at,
+                            updated_at = EXCLUDED.updated_at
+                        """,
+                        (destination_hash, now, now, now),
+                    )
         else:
             # mark all conversations as viewed
             self.provider.execute(
@@ -397,8 +399,9 @@ class MessageDAO:
         )
 
     def move_conversations_to_folder(self, peer_hashes, folder_id):
-        for peer_hash in peer_hashes:
-            self.move_conversation_to_folder(peer_hash, folder_id)
+        with self.provider:  # one transaction for the whole batch
+            for peer_hash in peer_hashes:
+                self.move_conversation_to_folder(peer_hash, folder_id)
 
     def get_all_conversation_folders(self):
         return self.provider.fetchall("SELECT * FROM lxmf_conversation_folders")
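
The `with self.provider:` blocks above assume the provider exposes the same context-manager protocol as `sqlite3.Connection`: commit on clean exit, rollback on exception. A minimal sketch of why batching matters for these N+1 loops (plain sqlite3, illustrative table name):

    import sqlite3

    conn = sqlite3.connect("batch_demo.db")
    conn.execute(
        "CREATE TABLE IF NOT EXISTS read_state ("
        "destination_hash TEXT PRIMARY KEY, last_read_at TEXT)"
    )
    rows = [(f"hash{i}", "2024-01-01T00:00:00") for i in range(200)]

    # If each execute() autocommits (as many DB wrappers do), every INSERT
    # pays its own transaction; the with-block turns the batch into a single
    # transaction with one commit at the end.
    with conn:
        for dest, ts in rows:
            conn.execute(
                "INSERT INTO read_state VALUES (?, ?) "
                "ON CONFLICT(destination_hash) DO UPDATE SET "
                "last_read_at = excluded.last_read_at",
                (dest, ts),
            )

`executemany` would compress the loop further, but the explicit loop mirrors the DAO code above, which interpolates nothing into the SQL and binds every value.
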
"admin'--", + "' AND 1=CONVERT(int,(SELECT TOP 1 table_name FROM information_schema.tables))--", + ] +) st_search_term = st.one_of(st_nasty_text, st_sql_payloads) @@ -68,6 +70,7 @@ st_hex_hash = st.from_regex(r"[0-9a-f]{16,64}", fullmatch=True) # Fixture: initialised in-memory database # --------------------------------------------------------------------------- + @pytest.fixture def db(): database = Database(":memory:") @@ -85,8 +88,8 @@ def handler(db): # ContactsDAO # =================================================================== -class TestContactsDAOFuzzing: +class TestContactsDAOFuzzing: @given( name=st_nasty_text, identity_hash=st_hex_hash, @@ -144,8 +147,8 @@ class TestContactsDAOFuzzing: # ConfigDAO # =================================================================== -class TestConfigDAOFuzzing: +class TestConfigDAOFuzzing: @given(key=st_nasty_text.filter(lambda x: len(x) > 0), value=st_nasty_text) @settings( deadline=None, @@ -187,8 +190,8 @@ class TestConfigDAOFuzzing: # MiscDAO — spam keywords, notifications, keyboard shortcuts # =================================================================== -class TestMiscDAOFuzzing: +class TestMiscDAOFuzzing: @given(keyword=st_nasty_text.filter(lambda x: len(x) > 0)) @settings( deadline=None, @@ -221,7 +224,9 @@ class TestMiscDAOFuzzing: suppress_health_check=[HealthCheck.function_scoped_fixture], max_examples=40, ) - def test_add_notification_never_crashes(self, db, ntype, remote_hash, title, content): + def test_add_notification_never_crashes( + self, db, ntype, remote_hash, title, content + ): db.misc.add_notification(ntype, remote_hash, title, content) notifications = db.misc.get_notifications() assert isinstance(notifications, list) @@ -262,8 +267,8 @@ class TestMiscDAOFuzzing: # TelephoneDAO # =================================================================== -class TestTelephoneDAOFuzzing: +class TestTelephoneDAOFuzzing: @given( name=st_nasty_text, identity_hash=st_hex_hash, @@ -314,8 +319,8 @@ class TestTelephoneDAOFuzzing: # VoicemailDAO # =================================================================== -class TestVoicemailDAOFuzzing: +class TestVoicemailDAOFuzzing: @given( name=st_nasty_text, identity_hash=st_hex_hash, @@ -354,8 +359,8 @@ class TestVoicemailDAOFuzzing: # DebugLogsDAO # =================================================================== -class TestDebugLogsDAOFuzzing: +class TestDebugLogsDAOFuzzing: @given( level=st.sampled_from(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), module=st_nasty_text, @@ -400,7 +405,10 @@ class TestDebugLogsDAOFuzzing: def test_count_matches_result_length(self, db, search, level): """get_total_count must be consistent with len(get_logs) when no limit truncation.""" results = db.debug_logs.get_logs( - search=search, level=level, limit=100000, offset=0, + search=search, + level=level, + limit=100000, + offset=0, ) count = db.debug_logs.get_total_count(search=search, level=level) assert count == len(results) @@ -410,8 +418,8 @@ class TestDebugLogsDAOFuzzing: # RingtoneDAO # =================================================================== -class TestRingtoneDAOFuzzing: +class TestRingtoneDAOFuzzing: @given( filename=st_nasty_text.filter(lambda x: len(x) > 0), display_name=st_nasty_text.filter(lambda x: len(x) > 0), @@ -445,13 +453,18 @@ class TestRingtoneDAOFuzzing: # MapDrawingsDAO # =================================================================== -class TestMapDrawingsDAOFuzzing: +class TestMapDrawingsDAOFuzzing: @given( name=st_nasty_text.filter(lambda x: len(x) > 
diff --git a/tests/backend/test_dao_fuzzing.py b/tests/backend/test_dao_fuzzing.py
index 10f1318..4271adb 100644
--- a/tests/backend/test_dao_fuzzing.py
+++ b/tests/backend/test_dao_fuzzing.py
@@ -42,22 +42,24 @@ st_nasty_text = st.text(
     max_size=300,
 )
 
-st_sql_payloads = st.sampled_from([
-    "'; DROP TABLE config; --",
-    "' OR '1'='1",
-    "\" OR \"1\"=\"1",
-    "1; SELECT * FROM sqlite_master",
-    "Robert'); DROP TABLE lxmf_messages;--",
-    "%",
-    "%%",
-    "_",
-    "\x00",
-    "' UNION SELECT key, value FROM config --",
-    "NULL",
-    "1=1",
-    "admin'--",
-    "' AND 1=CONVERT(int,(SELECT TOP 1 table_name FROM information_schema.tables))--",
-])
+st_sql_payloads = st.sampled_from(
+    [
+        "'; DROP TABLE config; --",
+        "' OR '1'='1",
+        '" OR "1"="1',
+        "1; SELECT * FROM sqlite_master",
+        "Robert'); DROP TABLE lxmf_messages;--",
+        "%",
+        "%%",
+        "_",
+        "\x00",
+        "' UNION SELECT key, value FROM config --",
+        "NULL",
+        "1=1",
+        "admin'--",
+        "' AND 1=CONVERT(int,(SELECT TOP 1 table_name FROM information_schema.tables))--",
+    ]
+)
 
 st_search_term = st.one_of(st_nasty_text, st_sql_payloads)
 
@@ -68,6 +70,7 @@ st_hex_hash = st.from_regex(r"[0-9a-f]{16,64}", fullmatch=True)
 
 # Fixture: initialised in-memory database
 # ---------------------------------------------------------------------------
+
 @pytest.fixture
 def db():
     database = Database(":memory:")
@@ -85,8 +88,8 @@ def handler(db):
 # ContactsDAO
 # ===================================================================
 
-class TestContactsDAOFuzzing:
 
+class TestContactsDAOFuzzing:
     @given(
         name=st_nasty_text,
         identity_hash=st_hex_hash,
@@ -144,8 +147,8 @@
 # ConfigDAO
 # ===================================================================
 
-class TestConfigDAOFuzzing:
 
+class TestConfigDAOFuzzing:
     @given(key=st_nasty_text.filter(lambda x: len(x) > 0), value=st_nasty_text)
     @settings(
         deadline=None,
@@ -187,8 +190,8 @@
 # MiscDAO — spam keywords, notifications, keyboard shortcuts
 # ===================================================================
 
-class TestMiscDAOFuzzing:
 
+class TestMiscDAOFuzzing:
     @given(keyword=st_nasty_text.filter(lambda x: len(x) > 0))
     @settings(
         deadline=None,
@@ -221,7 +224,9 @@
         suppress_health_check=[HealthCheck.function_scoped_fixture],
         max_examples=40,
     )
-    def test_add_notification_never_crashes(self, db, ntype, remote_hash, title, content):
+    def test_add_notification_never_crashes(
+        self, db, ntype, remote_hash, title, content
+    ):
         db.misc.add_notification(ntype, remote_hash, title, content)
         notifications = db.misc.get_notifications()
         assert isinstance(notifications, list)
@@ -262,8 +267,8 @@
 # TelephoneDAO
 # ===================================================================
 
-class TestTelephoneDAOFuzzing:
 
+class TestTelephoneDAOFuzzing:
     @given(
         name=st_nasty_text,
         identity_hash=st_hex_hash,
@@ -314,8 +319,8 @@
 # VoicemailDAO
 # ===================================================================
 
-class TestVoicemailDAOFuzzing:
 
+class TestVoicemailDAOFuzzing:
     @given(
         name=st_nasty_text,
         identity_hash=st_hex_hash,
@@ -354,8 +359,8 @@
 # DebugLogsDAO
 # ===================================================================
 
-class TestDebugLogsDAOFuzzing:
 
+class TestDebugLogsDAOFuzzing:
     @given(
         level=st.sampled_from(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
         module=st_nasty_text,
@@ -400,7 +405,10 @@
     def test_count_matches_result_length(self, db, search, level):
         """get_total_count must be consistent with len(get_logs) when no limit truncation."""
         results = db.debug_logs.get_logs(
-            search=search, level=level, limit=100000, offset=0,
+            search=search,
+            level=level,
+            limit=100000,
+            offset=0,
         )
         count = db.debug_logs.get_total_count(search=search, level=level)
         assert count == len(results)
@@ -410,8 +418,8 @@
 # RingtoneDAO
 # ===================================================================
 
-class TestRingtoneDAOFuzzing:
 
+class TestRingtoneDAOFuzzing:
     @given(
         filename=st_nasty_text.filter(lambda x: len(x) > 0),
         display_name=st_nasty_text.filter(lambda x: len(x) > 0),
@@ -445,13 +453,18 @@
 # MapDrawingsDAO
 # ===================================================================
 
-class TestMapDrawingsDAOFuzzing:
 
+class TestMapDrawingsDAOFuzzing:
     @given(
         name=st_nasty_text.filter(lambda x: len(x) > 0),
         data=st.one_of(
             st_nasty_text,
-            st.builds(json.dumps, st.dictionaries(st.text(max_size=10), st.text(max_size=50), max_size=10)),
+            st.builds(
+                json.dumps,
+                st.dictionaries(
+                    st.text(max_size=10), st.text(max_size=50), max_size=10
+                ),
+            ),
         ),
     )
     @settings(
@@ -470,8 +483,8 @@
 # MessageDAO — folders
 # ===================================================================
 
-class TestMessageDAOFoldersFuzzing:
 
+class TestMessageDAOFoldersFuzzing:
     @given(name=st_nasty_text.filter(lambda x: len(x) > 0))
     @settings(
         deadline=None,
@@ -507,8 +520,8 @@
 # MessageHandler — search
 # ===================================================================
 
-class TestMessageHandlerFuzzing:
 
+class TestMessageHandlerFuzzing:
     @given(search=st_search_term)
     @settings(
         deadline=None,
@@ -557,9 +570,14 @@
         suppress_health_check=[HealthCheck.function_scoped_fixture],
         max_examples=40,
     )
-    def test_get_conversation_messages_never_crashes(self, handler, dest, after_id, before_id):
+    def test_get_conversation_messages_never_crashes(
+        self, handler, dest, after_id, before_id
+    ):
         results = handler.get_conversation_messages(
-            "local_hash", dest, after_id=after_id, before_id=before_id,
+            "local_hash",
+            dest,
+            after_id=after_id,
+            before_id=before_id,
         )
         assert isinstance(results, list)
 
@@ -568,8 +586,8 @@
 # _safe_href — URL sanitization
 # ===================================================================
 
-class TestSafeHrefFuzzing:
 
+class TestSafeHrefFuzzing:
     @given(url=st.text(max_size=500))
     @settings(deadline=None, max_examples=200)
     def test_safe_href_never_returns_dangerous_protocol(self, url):
@@ -619,14 +637,20 @@
     def test_safe_href_blocks_all_generated_dangerous_urls(self, url):
         assert _safe_href(url) == "#"
 
-    @given(scheme=st.text(min_size=1, max_size=20).filter(lambda s: ":" not in s and "/" not in s))
+    @given(
+        scheme=st.text(min_size=1, max_size=20).filter(
+            lambda s: ":" not in s and "/" not in s
+        )
+    )
     @settings(deadline=None, max_examples=60)
     def test_safe_href_blocks_unknown_schemes(self, scheme):
         """Any unknown scheme like 'foo:bar' should be blocked."""
         url = f"{scheme}:something"
         result = _safe_href(url)
         lower = url.strip().lower()
-        if any(lower.startswith(p) for p in ("https://", "http://", "/", "#", "mailto:")):
+        if any(
+            lower.startswith(p) for p in ("https://", "http://", "/", "#", "mailto:")
+        ):
             assert result == url
         else:
             assert result == "#"
@@ -636,31 +660,34 @@
 # parse_bool_query_param
 # ===================================================================
 
-class TestParseBoolQueryParam:
 
+class TestParseBoolQueryParam:
     @given(value=st.text(max_size=100))
     @settings(deadline=None, max_examples=200)
     def test_never_crashes_and_returns_bool(self, value):
         result = parse_bool_query_param(value)
         assert isinstance(result, bool)
 
-    @pytest.mark.parametrize("value,expected", [
-        ("1", True),
-        ("true", True),
-        ("True", True),
-        ("TRUE", True),
-        ("yes", True),
-        ("YES", True),
-        ("on", True),
-        ("ON", True),
-        ("0", False),
-        ("false", False),
-        ("no", False),
-        ("off", False),
-        ("", False),
-        ("maybe", False),
-        (None, False),
-    ])
+    @pytest.mark.parametrize(
+        "value,expected",
+        [
+            ("1", True),
+            ("true", True),
+            ("True", True),
+            ("TRUE", True),
+            ("yes", True),
+            ("YES", True),
+            ("on", True),
+            ("ON", True),
+            ("0", False),
+            ("false", False),
+            ("no", False),
+            ("off", False),
+            ("", False),
+            ("maybe", False),
+            (None, False),
+        ],
+    )
     def test_known_values(self, value, expected):
         assert parse_bool_query_param(value) is expected
 
@@ -669,8 +696,8 @@
 # message_fields_have_attachments
 # ===================================================================
 
-class TestMessageFieldsHaveAttachments:
 
+class TestMessageFieldsHaveAttachments:
     @given(data=st.text(max_size=500))
     @settings(deadline=None, max_examples=100)
     def test_never_crashes_on_arbitrary_text(self, data):
@@ -689,25 +716,30 @@
         result = message_fields_have_attachments(json.dumps(obj))
         assert isinstance(result, bool)
 
-    @pytest.mark.parametrize("fields_json,expected", [
-        (None, False),
-        ("", False),
-        ("{}", False),
-        ("not json", False),
-        ('{"image": "base64data"}', True),
-        ('{"audio": "base64data"}', True),
-        ('{"file_attachments": [{"name": "f.txt"}]}', True),
-        ('{"file_attachments": []}', False),
-        ('{"other_field": "value"}', False),
-    ])
+    @pytest.mark.parametrize(
+        "fields_json,expected",
+        [
+            (None, False),
+            ("", False),
+            ("{}", False),
+            ("not json", False),
+            ('{"image": "base64data"}', True),
+            ('{"audio": "base64data"}', True),
+            ('{"file_attachments": [{"name": "f.txt"}]}', True),
+            ('{"file_attachments": []}', False),
+            ('{"other_field": "value"}', False),
+        ],
+    )
     def test_known_cases(self, fields_json, expected):
         assert message_fields_have_attachments(fields_json) is expected
 
     @given(
         data=st.recursive(
             st.one_of(st.text(max_size=20), st.integers(), st.booleans(), st.none()),
-            lambda children: st.lists(children, max_size=5)
-            | st.dictionaries(st.text(max_size=10), children, max_size=5),
+            lambda children: (
+                st.lists(children, max_size=5)
+                | st.dictionaries(st.text(max_size=10), children, max_size=5)
+            ),
             max_leaves=30,
         )
     )
("maybe", False), + (None, False), + ], + ) def test_known_values(self, value, expected): assert parse_bool_query_param(value) is expected @@ -669,8 +696,8 @@ class TestParseBoolQueryParam: # message_fields_have_attachments # =================================================================== -class TestMessageFieldsHaveAttachments: +class TestMessageFieldsHaveAttachments: @given(data=st.text(max_size=500)) @settings(deadline=None, max_examples=100) def test_never_crashes_on_arbitrary_text(self, data): @@ -689,25 +716,30 @@ class TestMessageFieldsHaveAttachments: result = message_fields_have_attachments(json.dumps(obj)) assert isinstance(result, bool) - @pytest.mark.parametrize("fields_json,expected", [ - (None, False), - ("", False), - ("{}", False), - ("not json", False), - ('{"image": "base64data"}', True), - ('{"audio": "base64data"}', True), - ('{"file_attachments": [{"name": "f.txt"}]}', True), - ('{"file_attachments": []}', False), - ('{"other_field": "value"}', False), - ]) + @pytest.mark.parametrize( + "fields_json,expected", + [ + (None, False), + ("", False), + ("{}", False), + ("not json", False), + ('{"image": "base64data"}', True), + ('{"audio": "base64data"}', True), + ('{"file_attachments": [{"name": "f.txt"}]}', True), + ('{"file_attachments": []}', False), + ('{"other_field": "value"}', False), + ], + ) def test_known_cases(self, fields_json, expected): assert message_fields_have_attachments(fields_json) is expected @given( data=st.recursive( st.one_of(st.text(max_size=20), st.integers(), st.booleans(), st.none()), - lambda children: st.lists(children, max_size=5) - | st.dictionaries(st.text(max_size=10), children, max_size=5), + lambda children: ( + st.lists(children, max_size=5) + | st.dictionaries(st.text(max_size=10), children, max_size=5) + ), max_leaves=30, ) ) diff --git a/tests/backend/test_performance_hotpaths.py b/tests/backend/test_performance_hotpaths.py index e5cfd09..da6c82b 100644 --- a/tests/backend/test_performance_hotpaths.py +++ b/tests/backend/test_performance_hotpaths.py @@ -22,7 +22,6 @@ import tempfile import threading import time import unittest -from unittest.mock import MagicMock from meshchatx.src.backend.announce_manager import AnnounceManager from meshchatx.src.backend.database import Database @@ -33,6 +32,7 @@ from meshchatx.src.backend.message_handler import MessageHandler # Helpers # --------------------------------------------------------------------------- + def percentile(data, pct): """Return the pct-th percentile of sorted data.""" if not data: @@ -108,8 +108,8 @@ def make_announce(i): # Test class # --------------------------------------------------------------------------- -class TestPerformanceHotPaths(unittest.TestCase): +class TestPerformanceHotPaths(unittest.TestCase): NUM_MESSAGES = 10_000 NUM_PEERS = 200 NUM_ANNOUNCES = 5_000 @@ -134,7 +134,7 @@ class TestPerformanceHotPaths(unittest.TestCase): @classmethod def _seed_data(cls): - print(f"\n--- Seeding test data ---") + print("\n--- Seeding test data ---") # Peers cls.peer_hashes = [secrets.token_hex(16) for _ in range(cls.NUM_PEERS)] @@ -286,7 +286,9 @@ class TestPerformanceHotPaths(unittest.TestCase): dest = secrets.token_hex(16) _, ms = timed_call( self.db.announces.upsert_favourite, - dest, f"Bench Fav {i}", "lxmf.delivery", + dest, + f"Bench Fav {i}", + "lxmf.delivery", ) durations.append(ms) @@ -316,7 +318,13 @@ class TestPerformanceHotPaths(unittest.TestCase): def test_conversations_search_latency(self): """Search conversations — LIKE across titles, content, peer 
hashes.""" print("\n[Conversations] Search:") - terms = ["Message title 5", "Content body", "abc", "zzz_nope", self.heavy_peer[:8]] + terms = [ + "Message title 5", + "Content body", + "abc", + "zzz_nope", + self.heavy_peer[:8], + ] durations = [] for term in terms: _, ms = timed_call( @@ -450,7 +458,9 @@ class TestPerformanceHotPaths(unittest.TestCase): with lock: all_durations.extend(thread_durations) - threads = [threading.Thread(target=writer, args=(t,)) for t in range(num_threads)] + threads = [ + threading.Thread(target=writer, args=(t,)) for t in range(num_threads) + ] t0 = time.perf_counter() for t in threads: t.start() @@ -460,8 +470,10 @@ class TestPerformanceHotPaths(unittest.TestCase): total_ops = num_threads * msgs_per_thread throughput = total_ops / (wall_ms / 1000) - print(f" Wall time: {wall_ms:.0f}ms for {total_ops} inserts ({throughput:.0f} ops/s)") - stats = latency_report("concurrent_write", all_durations) + print( + f" Wall time: {wall_ms:.0f}ms for {total_ops} inserts ({throughput:.0f} ops/s)" + ) + latency_report("concurrent_write", all_durations) self.assertEqual(len(errors), 0, f"Writer errors: {errors[:5]}") self.assertGreater(throughput, 100, "Concurrent write throughput < 100 ops/s") @@ -487,7 +499,9 @@ class TestPerformanceHotPaths(unittest.TestCase): with lock: all_durations.extend(thread_durations) - threads = [threading.Thread(target=writer, args=(t,)) for t in range(num_threads)] + threads = [ + threading.Thread(target=writer, args=(t,)) for t in range(num_threads) + ] t0 = time.perf_counter() for t in threads: t.start() @@ -497,8 +511,10 @@ class TestPerformanceHotPaths(unittest.TestCase): total_ops = num_threads * announces_per_thread throughput = total_ops / (wall_ms / 1000) - print(f" Wall time: {wall_ms:.0f}ms for {total_ops} upserts ({throughput:.0f} ops/s)") - stats = latency_report("concurrent_announce_write", all_durations) + print( + f" Wall time: {wall_ms:.0f}ms for {total_ops} upserts ({throughput:.0f} ops/s)" + ) + latency_report("concurrent_announce_write", all_durations) self.assertEqual(len(errors), 0, f"Writer errors: {errors[:5]}") self.assertGreater(throughput, 100, "Concurrent announce write < 100 ops/s") @@ -543,8 +559,12 @@ class TestPerformanceHotPaths(unittest.TestCase): with lock: read_durations.extend(local_durs) - writers = [threading.Thread(target=writer, args=(t,)) for t in range(num_writers)] - readers = [threading.Thread(target=reader, args=(t,)) for t in range(num_readers)] + writers = [ + threading.Thread(target=writer, args=(t,)) for t in range(num_writers) + ] + readers = [ + threading.Thread(target=reader, args=(t,)) for t in range(num_readers) + ] t0 = time.perf_counter() for t in writers + readers: @@ -593,6 +613,154 @@ class TestPerformanceHotPaths(unittest.TestCase): print(f" Contacts LIKE search ({self.NUM_CONTACTS} rows): {ms_con:.2f}ms") self.assertLess(ms_con, 50, "Contacts LIKE search > 50ms") + # =================================================================== + # N+1 BATCH OPERATIONS — transaction wrapping regression tests + # =================================================================== + + def test_mark_conversations_as_read_batch(self): + """mark_conversations_as_read should be fast for large batches (transaction-wrapped).""" + print("\n[Batch] mark_conversations_as_read:") + hashes = [secrets.token_hex(16) for _ in range(200)] + durations = [] + for _ in range(5): + _, ms = timed_call(self.db.messages.mark_conversations_as_read, hashes) + durations.append(ms) + + stats = 
latency_report("mark_read_200", durations) + self.assertLess(stats["p95"], 50, "mark_conversations_as_read(200) p95 > 50ms") + + def test_mark_all_notifications_as_viewed_batch(self): + """mark_all_notifications_as_viewed should be fast for large batches.""" + print("\n[Batch] mark_all_notifications_as_viewed:") + hashes = [secrets.token_hex(16) for _ in range(200)] + durations = [] + for _ in range(5): + _, ms = timed_call( + self.db.messages.mark_all_notifications_as_viewed, hashes + ) + durations.append(ms) + + stats = latency_report("mark_viewed_200", durations) + self.assertLess( + stats["p95"], 50, "mark_all_notifications_as_viewed(200) p95 > 50ms" + ) + + def test_move_conversations_to_folder_batch(self): + """move_conversations_to_folder should be fast for large batches.""" + print("\n[Batch] move_conversations_to_folder:") + self.db.messages.create_folder("perf_test_folder") + folders = self.db.messages.get_all_folders() + folder_id = folders[0]["id"] + + hashes = [secrets.token_hex(16) for _ in range(200)] + durations = [] + for _ in range(5): + _, ms = timed_call( + self.db.messages.move_conversations_to_folder, hashes, folder_id + ) + durations.append(ms) + + stats = latency_report("move_folder_200", durations) + self.assertLess( + stats["p95"], 50, "move_conversations_to_folder(200) p95 > 50ms" + ) + + # =================================================================== + # INDEX VERIFICATION — confirm new indexes are used + # =================================================================== + + def test_indexes_exist(self): + """Verify critical indexes exist in the schema.""" + print("\n[Indexes] Checking critical indexes exist:") + rows = self.db.provider.fetchall( + "SELECT name FROM sqlite_master WHERE type='index'" + ) + index_names = {r["name"] for r in rows} + + expected = [ + "idx_contacts_lxmf_address", + "idx_contacts_lxst_address", + "idx_notifications_is_viewed", + "idx_map_drawings_identity_hash", + "idx_map_drawings_identity_name", + "idx_voicemails_is_read", + "idx_archived_pages_created_at", + "idx_lxmf_messages_state_peer", + "idx_lxmf_messages_peer_hash", + "idx_lxmf_messages_peer_ts", + "idx_announces_updated_at", + "idx_announces_aspect", + ] + for idx in expected: + self.assertIn(idx, index_names, f"Missing index: {idx}") + print(f" {idx}: OK") + + def test_pragmas_applied(self): + """Verify performance PRAGMAs are active.""" + print("\n[PRAGMAs] Checking applied PRAGMAs:") + journal = self.db._get_pragma_value("journal_mode") + print(f" journal_mode: {journal}") + self.assertEqual(journal, "wal") + + sync = self.db._get_pragma_value("synchronous") + print(f" synchronous: {sync}") + self.assertEqual(sync, 1) # NORMAL = 1 + + temp_store = self.db._get_pragma_value("temp_store") + print(f" temp_store: {temp_store}") + self.assertEqual(temp_store, 2) # MEMORY = 2 + + cache_size = self.db._get_pragma_value("cache_size") + print(f" cache_size: {cache_size}") + self.assertLessEqual(cache_size, -8000) + + # =================================================================== + # QUERY PLAN CHECKS — confirm indexes are actually used + # =================================================================== + + def test_query_plan_messages_by_peer(self): + """The most common message query should use peer_hash index.""" + print("\n[Query Plan] Messages by peer_hash:") + rows = self.db.provider.fetchall( + "EXPLAIN QUERY PLAN SELECT * FROM lxmf_messages WHERE peer_hash = ? 
ORDER BY id DESC LIMIT 50", + ("test",), + ) + plan = " ".join(str(r["detail"]) for r in rows) + print(f" {plan}") + self.assertIn("idx_lxmf_messages_peer_hash", plan.lower()) + + def test_query_plan_announces_by_aspect(self): + """Announce filtering by aspect should use the aspect index.""" + print("\n[Query Plan] Announces by aspect:") + rows = self.db.provider.fetchall( + "EXPLAIN QUERY PLAN SELECT * FROM announces WHERE aspect = ? ORDER BY updated_at DESC LIMIT 50", + ("lxmf.delivery",), + ) + plan = " ".join(str(r["detail"]) for r in rows) + print(f" {plan}") + self.assertIn("idx_announces_aspect", plan.lower()) + + def test_query_plan_failed_messages_state_peer(self): + """The failed_count subquery should use the state+peer composite index.""" + print("\n[Query Plan] Failed messages (state, peer_hash):") + rows = self.db.provider.fetchall( + "EXPLAIN QUERY PLAN SELECT COUNT(*) FROM lxmf_messages WHERE state = 'failed' AND peer_hash = ?", + ("test",), + ) + plan = " ".join(str(r["detail"]) for r in rows) + print(f" {plan}") + self.assertIn("idx_lxmf_messages_state_peer", plan.lower()) + + def test_query_plan_notifications_unread(self): + """Notification unread filter should use the is_viewed index.""" + print("\n[Query Plan] Notifications unread:") + rows = self.db.provider.fetchall( + "EXPLAIN QUERY PLAN SELECT * FROM notifications WHERE is_viewed = 0 ORDER BY timestamp DESC LIMIT 50", + ) + plan = " ".join(str(r["detail"]) for r in rows) + print(f" {plan}") + self.assertIn("idx_notifications_is_viewed", plan.lower()) + if __name__ == "__main__": unittest.main() diff --git a/tests/backend/test_sql_injection_fixes.py b/tests/backend/test_sql_injection_fixes.py index 1d4e8d4..b1b94d4 100644 --- a/tests/backend/test_sql_injection_fixes.py +++ b/tests/backend/test_sql_injection_fixes.py @@ -9,7 +9,6 @@ Covers: import os import re import sqlite3 -import tempfile import pytest from hypothesis import HealthCheck, given, settings @@ -24,6 +23,7 @@ from meshchatx.src.backend.database.schema import DatabaseSchema, _validate_iden # Fixtures # --------------------------------------------------------------------------- + @pytest.fixture def db_env(tmp_path): """Provide an initialized DatabaseProvider + Schema in a temp directory.""" @@ -52,8 +52,8 @@ def _make_legacy_db(legacy_dir, identity_hash, tables_sql): # 1. ATTACH DATABASE — single-quote escaping # --------------------------------------------------------------------------- -class TestAttachDatabasePathEscaping: +class TestAttachDatabasePathEscaping: def test_path_without_quotes_migrates_normally(self, db_env): provider, _schema, tmp_path = db_env legacy_dir = str(tmp_path / "legacy_normal") @@ -153,8 +153,8 @@ class TestAttachDatabasePathEscaping: # 2. 
diff --git a/tests/backend/test_sql_injection_fixes.py b/tests/backend/test_sql_injection_fixes.py
index 1d4e8d4..b1b94d4 100644
--- a/tests/backend/test_sql_injection_fixes.py
+++ b/tests/backend/test_sql_injection_fixes.py
@@ -9,7 +9,6 @@ Covers:
 import os
 import re
 import sqlite3
-import tempfile
 
 import pytest
 from hypothesis import HealthCheck, given, settings
@@ -24,6 +23,7 @@ from meshchatx.src.backend.database.schema import DatabaseSchema, _validate_iden
 # Fixtures
 # ---------------------------------------------------------------------------
 
+
 @pytest.fixture
 def db_env(tmp_path):
     """Provide an initialized DatabaseProvider + Schema in a temp directory."""
@@ -52,8 +52,8 @@ def _make_legacy_db(legacy_dir, identity_hash, tables_sql):
 # 1. ATTACH DATABASE — single-quote escaping
 # ---------------------------------------------------------------------------
 
-class TestAttachDatabasePathEscaping:
 
+class TestAttachDatabasePathEscaping:
     def test_path_without_quotes_migrates_normally(self, db_env):
         provider, _schema, tmp_path = db_env
         legacy_dir = str(tmp_path / "legacy_normal")
@@ -153,8 +153,8 @@
 # 2. Legacy migrator — malicious column names filtered out
 # ---------------------------------------------------------------------------
 
-class TestLegacyColumnFiltering:
 
+class TestLegacyColumnFiltering:
     def test_normal_columns_migrate(self, db_env):
         """Standard column names pass through the identifier filter."""
         provider, _schema, tmp_path = db_env
@@ -188,9 +188,7 @@
         conn.execute(
             'CREATE TABLE config (key TEXT UNIQUE, value TEXT, "key; DROP TABLE config" TEXT)'
         )
-        conn.execute(
-            "INSERT INTO config (key, value) VALUES ('safe', 'data')"
-        )
+        conn.execute("INSERT INTO config (key, value) VALUES ('safe', 'data')")
         conn.commit()
         conn.close()
@@ -211,12 +209,8 @@
         os.makedirs(identity_dir, exist_ok=True)
         db_path = os.path.join(identity_dir, "database.db")
         conn = sqlite3.connect(db_path)
-        conn.execute(
-            'CREATE TABLE config (key TEXT UNIQUE, value TEXT, "evil()" TEXT)'
-        )
-        conn.execute(
-            "INSERT INTO config (key, value) VALUES ('p1', 'parens')"
-        )
+        conn.execute('CREATE TABLE config (key TEXT UNIQUE, value TEXT, "evil()" TEXT)')
+        conn.execute("INSERT INTO config (key, value) VALUES ('p1', 'parens')")
         conn.commit()
         conn.close()
@@ -231,8 +225,8 @@
 # 3. _validate_identifier — unit tests
 # ---------------------------------------------------------------------------
 
-class TestValidateIdentifier:
 
+class TestValidateIdentifier:
     @pytest.mark.parametrize(
         "name",
         [
@@ -273,8 +267,8 @@
 # 4. _ensure_column — rejects injection via table/column names
 # ---------------------------------------------------------------------------
 
-class TestEnsureColumnInjection:
 
+class TestEnsureColumnInjection:
     def test_ensure_column_rejects_malicious_table_name(self, db_env):
         _provider, schema, _tmp_path = db_env
         with pytest.raises(ValueError, match="Invalid SQL table name"):
@@ -354,6 +348,7 @@ def test_validate_identifier_rejects_pure_metacharacter_strings(name):
 # 6. ATTACH path escaping — property-based
 # ---------------------------------------------------------------------------
 
+
 @given(
     path_segment=st.text(
         alphabet=st.characters(