From 9d7ae2017b7a4f378eaebadf3a24bf066e5d3eb9 Mon Sep 17 00:00:00 2001 From: Sudo-Ivan Date: Fri, 6 Mar 2026 03:08:28 -0600 Subject: [PATCH] Add SQL injection prevention measures in database migration and schema validation - Implemented identifier validation regex in legacy migrator and schema. - Enhanced database path handling to escape single quotes during ATTACH DATABASE. - Added tests for SQL injection scenarios, ensuring robustness against malicious inputs. - Introduced fuzz tests for DAO layers to cover edge cases and improve overall test coverage. --- .../src/backend/database/legacy_migrator.py | 10 +- meshchatx/src/backend/database/schema.py | 28 +- tests/backend/test_dao_fuzzing.py | 717 ++++++++++++++++++ tests/backend/test_performance_hotpaths.py | 598 +++++++++++++++ tests/backend/test_sql_injection_fixes.py | 392 ++++++++++ 5 files changed, 1731 insertions(+), 14 deletions(-) create mode 100644 tests/backend/test_dao_fuzzing.py create mode 100644 tests/backend/test_performance_hotpaths.py create mode 100644 tests/backend/test_sql_injection_fixes.py diff --git a/meshchatx/src/backend/database/legacy_migrator.py b/meshchatx/src/backend/database/legacy_migrator.py index 5c06675..596fa50 100644 --- a/meshchatx/src/backend/database/legacy_migrator.py +++ b/meshchatx/src/backend/database/legacy_migrator.py @@ -1,4 +1,7 @@ import os +import re + +_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") class LegacyMigrator: @@ -73,7 +76,8 @@ class LegacyMigrator: # Attach the legacy database # We use a randomized alias to avoid collisions alias = f"legacy_{os.urandom(4).hex()}" - self.provider.execute(f"ATTACH DATABASE '{legacy_path}' AS {alias}") + safe_path = legacy_path.replace("'", "''") + self.provider.execute(f"ATTACH DATABASE '{safe_path}' AS {alias}") # noqa: S608 # Tables that existed in the legacy Peewee version tables_to_migrate = [ @@ -121,7 +125,9 @@ class LegacyMigrator: common_columns = [ col for col in legacy_columns - if col in current_columns and col.lower() != "id" + if col in current_columns + and col.lower() != "id" + and _IDENTIFIER_RE.match(col) ] if common_columns: diff --git a/meshchatx/src/backend/database/schema.py b/meshchatx/src/backend/database/schema.py index bbe9b75..2a76715 100644 --- a/meshchatx/src/backend/database/schema.py +++ b/meshchatx/src/backend/database/schema.py @@ -1,5 +1,16 @@ +import re + from .provider import DatabaseProvider +_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +def _validate_identifier(name: str, label: str = "identifier") -> str: + if not _IDENTIFIER_RE.match(name): + msg = f"Invalid SQL {label}: {name!r}" + raise ValueError(msg) + return name + class DatabaseSchema: LATEST_VERSION = 38 @@ -28,21 +39,18 @@ class DatabaseSchema: def _ensure_column(self, table_name, column_name, column_type): """Add a column to a table if it doesn't exist.""" - # First check if it exists using PRAGMA + _validate_identifier(table_name, "table name") + _validate_identifier(column_name, "column name") + cursor = self.provider.connection.cursor() try: - cursor.execute(f"PRAGMA table_info({table_name})") + cursor.execute(f"PRAGMA table_info({table_name})") # noqa: S608 columns = [row[1] for row in cursor.fetchall()] finally: cursor.close() if column_name not in columns: try: - # SQLite has limitations on ALTER TABLE ADD COLUMN: - # 1. Cannot add UNIQUE or PRIMARY KEY columns - # 2. 
Cannot add columns with non-constant defaults (like CURRENT_TIMESTAMP) - - # Strip non-constant defaults if present for the ALTER TABLE statement stmt_type = column_type forbidden_defaults = [ "CURRENT_TIMESTAMP", @@ -51,9 +59,6 @@ class DatabaseSchema: ] for forbidden in forbidden_defaults: if f"DEFAULT {forbidden}" in stmt_type.upper(): - # Remove the DEFAULT part for the ALTER statement - import re - stmt_type = re.sub( f"DEFAULT\\s+{forbidden}", "", @@ -61,9 +66,8 @@ class DatabaseSchema: flags=re.IGNORECASE, ).strip() - # Use the connection directly to avoid any middle-ware issues res = self._safe_execute( - f"ALTER TABLE {table_name} ADD COLUMN {column_name} {stmt_type}", + f"ALTER TABLE {table_name} ADD COLUMN {column_name} {stmt_type}", # noqa: S608 ) return res is not None except Exception as e: diff --git a/tests/backend/test_dao_fuzzing.py b/tests/backend/test_dao_fuzzing.py new file mode 100644 index 0000000..10f1318 --- /dev/null +++ b/tests/backend/test_dao_fuzzing.py @@ -0,0 +1,717 @@ +"""Property-based and fuzz tests for DAO layers, MessageHandler, and utility functions. + +Covers gaps identified in the existing test suite: + - ContactsDAO: add, search, update with adversarial strings + - ConfigDAO: get/set/delete round-trip with arbitrary keys and values + - MiscDAO: spam keywords, notifications, keyboard shortcuts + - TelephoneDAO: add + search with adversarial strings + - VoicemailDAO: add + search + - DebugLogsDAO: insert + search + count consistency + - RingtoneDAO: add + update with adversarial display names + - MapDrawingsDAO: upsert + update with arbitrary JSON data + - MessageDAO: folder create/rename + - MessageHandler: search_messages, get_conversations with adversarial search terms + - _safe_href: URL sanitization (XSS vectors) + - parse_bool_query_param: arbitrary strings + - message_fields_have_attachments: arbitrary JSON +""" + +import json + +import pytest +from hypothesis import HealthCheck, given, settings +from hypothesis import strategies as st + +from meshchatx.src.backend.database import Database +from meshchatx.src.backend.markdown_renderer import _safe_href +from meshchatx.src.backend.meshchat_utils import ( + message_fields_have_attachments, + parse_bool_query_param, +) +from meshchatx.src.backend.message_handler import MessageHandler + + +# --------------------------------------------------------------------------- +# Strategies +# --------------------------------------------------------------------------- + +# Strings that are valid for most text columns but include adversarial chars +st_nasty_text = st.text( + alphabet=st.characters(whitelist_categories=("L", "N", "P", "S", "Z", "C")), + min_size=0, + max_size=300, +) + +st_sql_payloads = st.sampled_from([ + "'; DROP TABLE config; --", + "' OR '1'='1", + "\" OR \"1\"=\"1", + "1; SELECT * FROM sqlite_master", + "Robert'); DROP TABLE lxmf_messages;--", + "%", + "%%", + "_", + "\x00", + "' UNION SELECT key, value FROM config --", + "NULL", + "1=1", + "admin'--", + "' AND 1=CONVERT(int,(SELECT TOP 1 table_name FROM information_schema.tables))--", +]) + +st_search_term = st.one_of(st_nasty_text, st_sql_payloads) + +st_hex_hash = st.from_regex(r"[0-9a-f]{16,64}", fullmatch=True) + + +# --------------------------------------------------------------------------- +# Fixture: initialised in-memory database +# --------------------------------------------------------------------------- + +@pytest.fixture +def db(): + database = Database(":memory:") + database.initialize() + yield database + database.close() 
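+
+# NOTE: Hypothesis reuses one function-scoped fixture instance across all
+# generated examples, which is why the tests below suppress the
+# function_scoped_fixture health check rather than rebuilding the database
+# per example.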
+ + +@pytest.fixture +def handler(db): + return MessageHandler(db) + + +# =================================================================== +# ContactsDAO +# =================================================================== + +class TestContactsDAOFuzzing: + + @given( + name=st_nasty_text, + identity_hash=st_hex_hash, + lxmf_addr=st.one_of(st.none(), st_hex_hash), + lxst_addr=st.one_of(st.none(), st_hex_hash), + ) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=60, + ) + def test_add_contact_roundtrip(self, db, name, identity_hash, lxmf_addr, lxst_addr): + db.contacts.add_contact( + name=name, + remote_identity_hash=identity_hash, + lxmf_address=lxmf_addr, + lxst_address=lxst_addr, + ) + row = db.contacts.get_contact_by_identity_hash(identity_hash) + assert row is not None + assert row["name"] == name + + @given(search=st_search_term) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=80, + ) + def test_search_contacts_never_crashes(self, db, search): + results = db.contacts.get_contacts(search=search) + assert isinstance(results, list) + count = db.contacts.get_contacts_count(search=search) + assert isinstance(count, int) + assert count >= 0 + + @given( + name=st_nasty_text, + identity_hash=st_hex_hash, + ) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=40, + ) + def test_update_contact_never_crashes(self, db, name, identity_hash): + db.contacts.add_contact(name="orig", remote_identity_hash=identity_hash) + row = db.contacts.get_contact_by_identity_hash(identity_hash) + assert row is not None + db.contacts.update_contact(row["id"], name=name) + updated = db.contacts.get_contact(row["id"]) + assert updated["name"] == name + + +# =================================================================== +# ConfigDAO +# =================================================================== + +class TestConfigDAOFuzzing: + + @given(key=st_nasty_text.filter(lambda x: len(x) > 0), value=st_nasty_text) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=80, + ) + def test_config_set_get_roundtrip(self, db, key, value): + db.config.set(key, value) + got = db.config.get(key) + assert got == value + + @given(key=st_nasty_text.filter(lambda x: len(x) > 0)) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=40, + ) + def test_config_delete_never_crashes(self, db, key): + db.config.set(key, "temp") + db.config.delete(key) + assert db.config.get(key) is None + + @given(key=st_sql_payloads) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + ) + def test_config_sql_injection_keys(self, db, key): + db.config.set(key, "injected?") + assert db.config.get(key) == "injected?" 
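+        # The round-trip above proves the key was bound as data; the check
+        # below additionally confirms the payload did not drop any table.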
+ tables = db.provider.fetchall( + "SELECT name FROM sqlite_master WHERE type='table'" + ) + table_names = {r["name"] for r in tables} + assert "config" in table_names + + +# =================================================================== +# MiscDAO — spam keywords, notifications, keyboard shortcuts +# =================================================================== + +class TestMiscDAOFuzzing: + + @given(keyword=st_nasty_text.filter(lambda x: len(x) > 0)) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=50, + ) + def test_spam_keyword_roundtrip(self, db, keyword): + db.misc.add_spam_keyword(keyword) + keywords = db.misc.get_spam_keywords() + assert any(k["keyword"] == keyword for k in keywords) + + @given(title=st_nasty_text, content=st_nasty_text) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=50, + ) + def test_check_spam_never_crashes(self, db, title, content): + result = db.misc.check_spam_keywords(title, content) + assert isinstance(result, bool) + + @given( + ntype=st.sampled_from(["message", "call", "voicemail", "system"]), + remote_hash=st_hex_hash, + title=st_nasty_text, + content=st_nasty_text, + ) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=40, + ) + def test_add_notification_never_crashes(self, db, ntype, remote_hash, title, content): + db.misc.add_notification(ntype, remote_hash, title, content) + notifications = db.misc.get_notifications() + assert isinstance(notifications, list) + + @given( + action=st_nasty_text.filter(lambda x: len(x) > 0), + keys=st_nasty_text, + ) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=40, + ) + def test_keyboard_shortcut_roundtrip(self, db, action, keys): + identity = "testhash" + db.misc.upsert_keyboard_shortcut(identity, action, keys) + shortcuts = db.misc.get_keyboard_shortcuts(identity) + assert any(s["action"] == action and s["keys"] == keys for s in shortcuts) + + @given( + query=st_search_term, + dest=st.one_of(st.none(), st_hex_hash), + ) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=40, + ) + def test_archived_pages_search_never_crashes(self, db, query, dest): + results = db.misc.get_archived_pages_paginated( + destination_hash=dest, + query=query if query else None, + ) + assert isinstance(results, list) + + +# =================================================================== +# TelephoneDAO +# =================================================================== + +class TestTelephoneDAOFuzzing: + + @given( + name=st_nasty_text, + identity_hash=st_hex_hash, + status=st.sampled_from(["answered", "missed", "rejected", "busy"]), + ) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=40, + ) + def test_add_call_history_never_crashes(self, db, name, identity_hash, status): + import time + + db.telephone.add_call_history( + remote_identity_hash=identity_hash, + remote_identity_name=name, + is_incoming=True, + status=status, + duration_seconds=42, + timestamp=time.time(), + ) + history = db.telephone.get_call_history() + assert isinstance(history, list) + assert len(history) > 0 + + @given(search=st_search_term) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=60, + ) + def test_call_history_search_never_crashes(self, 
db, search): + results = db.telephone.get_call_history(search=search) + assert isinstance(results, list) + + @given(search=st_search_term) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=60, + ) + def test_call_recordings_search_never_crashes(self, db, search): + results = db.telephone.get_call_recordings(search=search) + assert isinstance(results, list) + + +# =================================================================== +# VoicemailDAO +# =================================================================== + +class TestVoicemailDAOFuzzing: + + @given( + name=st_nasty_text, + identity_hash=st_hex_hash, + filename=st_nasty_text, + ) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=40, + ) + def test_add_voicemail_never_crashes(self, db, name, identity_hash, filename): + import time + + db.voicemails.add_voicemail( + remote_identity_hash=identity_hash, + remote_identity_name=name, + filename=filename, + duration_seconds=10, + timestamp=time.time(), + ) + vms = db.voicemails.get_voicemails() + assert isinstance(vms, list) + + @given(search=st_search_term) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=60, + ) + def test_voicemail_search_never_crashes(self, db, search): + results = db.voicemails.get_voicemails(search=search) + assert isinstance(results, list) + + +# =================================================================== +# DebugLogsDAO +# =================================================================== + +class TestDebugLogsDAOFuzzing: + + @given( + level=st.sampled_from(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), + module=st_nasty_text, + message=st_nasty_text, + ) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=50, + ) + def test_insert_log_never_crashes(self, db, level, module, message): + db.debug_logs.insert_log(level, module, message) + total = db.debug_logs.get_total_count() + assert total > 0 + + @given( + search=st_search_term, + level=st.one_of(st.none(), st.sampled_from(["DEBUG", "INFO", "ERROR"])), + module=st.one_of(st.none(), st_nasty_text), + ) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=60, + ) + def test_search_logs_never_crashes(self, db, search, level, module): + results = db.debug_logs.get_logs(search=search, level=level, module=module) + assert isinstance(results, list) + count = db.debug_logs.get_total_count(search=search, level=level, module=module) + assert isinstance(count, int) + assert count >= 0 + + @given( + search=st_search_term, + level=st.one_of(st.none(), st.sampled_from(["DEBUG", "INFO", "ERROR"])), + ) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=40, + ) + def test_count_matches_result_length(self, db, search, level): + """get_total_count must be consistent with len(get_logs) when no limit truncation.""" + results = db.debug_logs.get_logs( + search=search, level=level, limit=100000, offset=0, + ) + count = db.debug_logs.get_total_count(search=search, level=level) + assert count == len(results) + + +# =================================================================== +# RingtoneDAO +# =================================================================== + +class TestRingtoneDAOFuzzing: + + @given( + filename=st_nasty_text.filter(lambda x: len(x) > 0), + 
display_name=st_nasty_text.filter(lambda x: len(x) > 0), + ) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=40, + ) + def test_add_ringtone_roundtrip(self, db, filename, display_name): + rid = db.ringtones.add(filename, f"store_{filename[:20]}", display_name) + assert rid is not None + row = db.ringtones.get_by_id(rid) + assert row is not None + assert row["display_name"] == display_name + + @given(new_name=st_nasty_text.filter(lambda x: len(x) > 0)) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=30, + ) + def test_update_display_name(self, db, new_name): + rid = db.ringtones.add("test.ogg", "store_test.ogg", "Original") + db.ringtones.update(rid, display_name=new_name) + row = db.ringtones.get_by_id(rid) + assert row["display_name"] == new_name + + +# =================================================================== +# MapDrawingsDAO +# =================================================================== + +class TestMapDrawingsDAOFuzzing: + + @given( + name=st_nasty_text.filter(lambda x: len(x) > 0), + data=st.one_of( + st_nasty_text, + st.builds(json.dumps, st.dictionaries(st.text(max_size=10), st.text(max_size=50), max_size=10)), + ), + ) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=40, + ) + def test_upsert_drawing_roundtrip(self, db, name, data): + identity = "deadbeef01234567" + db.map_drawings.upsert_drawing(identity, name, data) + drawings = db.map_drawings.get_drawings(identity) + assert any(d["name"] == name for d in drawings) + + +# =================================================================== +# MessageDAO — folders +# =================================================================== + +class TestMessageDAOFoldersFuzzing: + + @given(name=st_nasty_text.filter(lambda x: len(x) > 0)) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=40, + ) + def test_create_folder_roundtrip(self, db, name): + db.messages.create_folder(name) + folders = db.messages.get_all_folders() + assert any(f["name"] == name for f in folders) + + @given( + original=st_nasty_text.filter(lambda x: len(x) > 0), + renamed=st_nasty_text.filter(lambda x: len(x) > 0), + ) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=30, + ) + def test_rename_folder(self, db, original, renamed): + db.provider.execute("DELETE FROM lxmf_folders") + cursor = db.messages.create_folder(original) + folder_id = cursor.lastrowid + db.messages.rename_folder(folder_id, renamed) + folders = db.messages.get_all_folders() + found = [f for f in folders if f["id"] == folder_id] + assert len(found) == 1 + assert found[0]["name"] == renamed + + +# =================================================================== +# MessageHandler — search +# =================================================================== + +class TestMessageHandlerFuzzing: + + @given(search=st_search_term) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=80, + ) + def test_search_messages_never_crashes(self, handler, search): + results = handler.search_messages("local_hash", search) + assert isinstance(results, list) + + @given(search=st_search_term) + @settings( + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + max_examples=60, + ) + def 
test_get_conversations_search_never_crashes(self, handler, search):
+        results = handler.get_conversations("local_hash", search=search)
+        assert isinstance(results, list)
+
+    @given(
+        search=st_search_term,
+        folder_id=st.one_of(st.none(), st.integers(min_value=0, max_value=1000)),
+    )
+    @settings(
+        deadline=None,
+        suppress_health_check=[HealthCheck.function_scoped_fixture],
+        max_examples=40,
+    )
+    def test_get_conversations_with_filters(self, handler, search, folder_id):
+        results = handler.get_conversations(
+            "local_hash",
+            search=search,
+            folder_id=folder_id,
+            filter_unread=True,
+        )
+        assert isinstance(results, list)
+
+    @given(
+        dest=st_hex_hash,
+        after_id=st.one_of(st.none(), st.integers(min_value=-1000, max_value=1000)),
+        before_id=st.one_of(st.none(), st.integers(min_value=-1000, max_value=1000)),
+    )
+    @settings(
+        deadline=None,
+        suppress_health_check=[HealthCheck.function_scoped_fixture],
+        max_examples=40,
+    )
+    def test_get_conversation_messages_never_crashes(self, handler, dest, after_id, before_id):
+        results = handler.get_conversation_messages(
+            "local_hash", dest, after_id=after_id, before_id=before_id,
+        )
+        assert isinstance(results, list)
+
+
+# ===================================================================
+# _safe_href — URL sanitization
+# ===================================================================
+
+class TestSafeHrefFuzzing:
+
+    @given(url=st.text(max_size=500))
+    @settings(deadline=None, max_examples=200)
+    def test_safe_href_never_returns_dangerous_protocol(self, url):
+        result = _safe_href(url)
+        lower = result.strip().lower()
+        assert not lower.startswith("javascript:")
+        assert not lower.startswith("data:")
+        assert not lower.startswith("vbscript:")
+        assert not lower.startswith("file:")
+
+    @pytest.mark.parametrize(
+        "url",
+        [
+            "javascript:alert(1)",
+            "JAVASCRIPT:alert(1)",
+            " javascript:alert(1) ",
+            "jAvAsCrIpT:alert(document.cookie)",
+            "data:text/html,<script>alert(1)</script>",
+            "DATA:text/html;base64,PHNjcmlwdD4=",
+            "vbscript:MsgBox('xss')",
+            "file:///etc/passwd",
+        ],
+    )
+    def test_safe_href_blocks_xss_vectors(self, url):
+        assert _safe_href(url) == "#"
+
+    @pytest.mark.parametrize(
+        "url",
+        [
+            "https://example.com",
+            "http://example.com",
+            "/relative/path",
+            "#anchor",
+            "mailto:test@example.com",
+        ],
+    )
+    def test_safe_href_allows_safe_urls(self, url):
+        assert _safe_href(url) == url
+
+    @given(
+        url=st.from_regex(
+            r"(javascript|data|vbscript|file):[A-Za-z0-9()=;,/+]+",
+            fullmatch=True,
+        )
+    )
+    @settings(deadline=None, max_examples=100)
+    def test_safe_href_blocks_all_generated_dangerous_urls(self, url):
+        assert _safe_href(url) == "#"
+
+    @given(scheme=st.text(min_size=1, max_size=20).filter(lambda s: ":" not in s and "/" not in s))
+    @settings(deadline=None, max_examples=60)
+    def test_safe_href_blocks_unknown_schemes(self, scheme):
+        """Any unknown scheme like 'foo:bar' should be blocked."""
+        url = f"{scheme}:something"
+        result = _safe_href(url)
+        lower = url.strip().lower()
+        if any(lower.startswith(p) for p in ("https://", "http://", "/", "#", "mailto:")):
+            assert result == url
+        else:
+            assert result == "#"
+
+
+# ===================================================================
+# parse_bool_query_param
+# ===================================================================
+
+class TestParseBoolQueryParam:
+
+    @given(value=st.text(max_size=100))
+    @settings(deadline=None, max_examples=200)
+    def test_never_crashes_and_returns_bool(self, value):
+        result = parse_bool_query_param(value)
+        assert isinstance(result, 
bool) + + @pytest.mark.parametrize("value,expected", [ + ("1", True), + ("true", True), + ("True", True), + ("TRUE", True), + ("yes", True), + ("YES", True), + ("on", True), + ("ON", True), + ("0", False), + ("false", False), + ("no", False), + ("off", False), + ("", False), + ("maybe", False), + (None, False), + ]) + def test_known_values(self, value, expected): + assert parse_bool_query_param(value) is expected + + +# =================================================================== +# message_fields_have_attachments +# =================================================================== + +class TestMessageFieldsHaveAttachments: + + @given(data=st.text(max_size=500)) + @settings(deadline=None, max_examples=100) + def test_never_crashes_on_arbitrary_text(self, data): + result = message_fields_have_attachments(data) + assert isinstance(result, bool) + + @given( + obj=st.dictionaries( + st.text(max_size=20), + st.one_of(st.text(max_size=50), st.integers(), st.booleans(), st.none()), + max_size=10, + ) + ) + @settings(deadline=None, max_examples=80) + def test_never_crashes_on_arbitrary_json_objects(self, obj): + result = message_fields_have_attachments(json.dumps(obj)) + assert isinstance(result, bool) + + @pytest.mark.parametrize("fields_json,expected", [ + (None, False), + ("", False), + ("{}", False), + ("not json", False), + ('{"image": "base64data"}', True), + ('{"audio": "base64data"}', True), + ('{"file_attachments": [{"name": "f.txt"}]}', True), + ('{"file_attachments": []}', False), + ('{"other_field": "value"}', False), + ]) + def test_known_cases(self, fields_json, expected): + assert message_fields_have_attachments(fields_json) is expected + + @given( + data=st.recursive( + st.one_of(st.text(max_size=20), st.integers(), st.booleans(), st.none()), + lambda children: st.lists(children, max_size=5) + | st.dictionaries(st.text(max_size=10), children, max_size=5), + max_leaves=30, + ) + ) + @settings(deadline=None, max_examples=60) + def test_deeply_nested_json_never_crashes(self, data): + result = message_fields_have_attachments(json.dumps(data)) + assert isinstance(result, bool) diff --git a/tests/backend/test_performance_hotpaths.py b/tests/backend/test_performance_hotpaths.py new file mode 100644 index 0000000..e5cfd09 --- /dev/null +++ b/tests/backend/test_performance_hotpaths.py @@ -0,0 +1,598 @@ +"""Performance regression tests for the critical hot paths. + +Focus areas (user priority): + - NomadNet browser: load announces, search announces, favourites + - Messages: load conversations, search messages, load conversation messages, + upsert messages (drafts) + +Metrics collected: + - ops/sec throughput + - p50 / p95 / p99 latency + - Concurrent writer contention + - LIKE-search scaling + +All tests have hard assertions so regressions fail CI. 
+""" + +import os +import secrets +import shutil +import statistics +import tempfile +import threading +import time +import unittest +from unittest.mock import MagicMock + +from meshchatx.src.backend.announce_manager import AnnounceManager +from meshchatx.src.backend.database import Database +from meshchatx.src.backend.message_handler import MessageHandler + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def percentile(data, pct): + """Return the pct-th percentile of sorted data.""" + if not data: + return 0 + s = sorted(data) + k = (len(s) - 1) * (pct / 100) + f = int(k) + c = f + 1 + if c >= len(s): + return s[f] + return s[f] + (k - f) * (s[c] - s[f]) + + +def timed_call(fn, *args, **kwargs): + """Call fn and return (result, duration_ms).""" + t0 = time.perf_counter() + result = fn(*args, **kwargs) + return result, (time.perf_counter() - t0) * 1000 + + +def latency_report(name, durations_ms): + """Print and return latency stats.""" + p50 = percentile(durations_ms, 50) + p95 = percentile(durations_ms, 95) + p99 = percentile(durations_ms, 99) + avg = statistics.mean(durations_ms) + ops = 1000 / avg if avg > 0 else float("inf") + print( + f" {name}: avg={avg:.2f}ms p50={p50:.2f}ms p95={p95:.2f}ms " + f"p99={p99:.2f}ms ops/s={ops:.0f}" + ) + return {"avg": avg, "p50": p50, "p95": p95, "p99": p99, "ops": ops} + + +def make_message(peer_hash, i, content_size=100): + return { + "hash": secrets.token_hex(16), + "source_hash": peer_hash, + "destination_hash": "local_hash_0" * 2, + "peer_hash": peer_hash, + "state": "delivered", + "progress": 1.0, + "is_incoming": i % 2, + "method": "direct", + "delivery_attempts": 1, + "next_delivery_attempt_at": None, + "title": f"Message title {i} " + secrets.token_hex(8), + "content": f"Content body {i} " + "x" * content_size, + "fields": "{}", + "timestamp": time.time() - i, + "rssi": -50, + "snr": 5.0, + "quality": 3, + "is_spam": 0, + "reply_to_hash": None, + } + + +def make_announce(i): + return { + "destination_hash": secrets.token_hex(16), + "aspect": "lxmf.delivery" if i % 3 != 0 else "lxst.telephony", + "identity_hash": secrets.token_hex(16), + "identity_public_key": "pubkey_" + secrets.token_hex(8), + "app_data": "appdata_" + secrets.token_hex(16), + "rssi": -50 + (i % 30), + "snr": 5.0, + "quality": i % 10, + } + + +# --------------------------------------------------------------------------- +# Test class +# --------------------------------------------------------------------------- + +class TestPerformanceHotPaths(unittest.TestCase): + + NUM_MESSAGES = 10_000 + NUM_PEERS = 200 + NUM_ANNOUNCES = 5_000 + NUM_FAVOURITES = 100 + NUM_CONTACTS = 50 + + @classmethod + def setUpClass(cls): + cls.test_dir = tempfile.mkdtemp() + cls.db_path = os.path.join(cls.test_dir, "perf_hotpaths.db") + cls.db = Database(cls.db_path) + cls.db.initialize() + cls.handler = MessageHandler(cls.db) + cls.announce_mgr = AnnounceManager(cls.db) + + cls._seed_data() + + @classmethod + def tearDownClass(cls): + cls.db.close_all() + shutil.rmtree(cls.test_dir, ignore_errors=True) + + @classmethod + def _seed_data(cls): + print(f"\n--- Seeding test data ---") + + # Peers + cls.peer_hashes = [secrets.token_hex(16) for _ in range(cls.NUM_PEERS)] + cls.heavy_peer = cls.peer_hashes[0] + + # Messages: distribute across peers, heavy_peer gets 2000 + print(f" Seeding {cls.NUM_MESSAGES} messages across {cls.NUM_PEERS} peers...") + t0 = time.perf_counter() + with 
cls.db.provider: + for i in range(cls.NUM_MESSAGES): + if i < 2000: + peer = cls.heavy_peer + else: + peer = cls.peer_hashes[i % cls.NUM_PEERS] + cls.db.messages.upsert_lxmf_message(make_message(peer, i)) + print(f" Done in {(time.perf_counter() - t0) * 1000:.0f}ms") + + # Announces + print(f" Seeding {cls.NUM_ANNOUNCES} announces...") + cls.announce_hashes = [] + t0 = time.perf_counter() + with cls.db.provider: + for i in range(cls.NUM_ANNOUNCES): + data = make_announce(i) + cls.announce_hashes.append(data["destination_hash"]) + cls.db.announces.upsert_announce(data) + print(f" Done in {(time.perf_counter() - t0) * 1000:.0f}ms") + + # Favourites + print(f" Seeding {cls.NUM_FAVOURITES} favourites...") + with cls.db.provider: + for i in range(cls.NUM_FAVOURITES): + cls.db.announces.upsert_favourite( + cls.announce_hashes[i], + f"Fav Node {i}", + "lxmf.delivery", + ) + + # Contacts (for JOIN benchmarks) + print(f" Seeding {cls.NUM_CONTACTS} contacts...") + with cls.db.provider: + for i in range(cls.NUM_CONTACTS): + cls.db.contacts.add_contact( + name=f"Contact {i}", + remote_identity_hash=cls.peer_hashes[i % cls.NUM_PEERS], + lxmf_address=cls.peer_hashes[(i + 1) % cls.NUM_PEERS], + ) + + print("--- Seeding complete ---\n") + + # =================================================================== + # ANNOUNCES — load, search, count + # =================================================================== + + def test_announce_load_filtered_latency(self): + """Load announces filtered by aspect with pagination — the NomadNet browser default view.""" + print("\n[Announce] Filtered load (aspect + pagination):") + durations = [] + offsets = [0, 100, 500, 1000, 2000] + for offset in offsets: + _, ms = timed_call( + self.announce_mgr.get_filtered_announces, + aspect="lxmf.delivery", + limit=50, + offset=offset, + ) + durations.append(ms) + + stats = latency_report("filtered_load", durations) + self.assertLess(stats["p99"], 100, "Announce filtered load p99 > 100ms") + + def test_announce_search_latency(self): + """Search announces by destination/identity hash substring.""" + print("\n[Announce] LIKE search:") + search_terms = [ + secrets.token_hex(4), + "abc", + self.announce_hashes[0][:8], + self.announce_hashes[2500][:10], + "nonexistent_term_xyz", + ] + durations = [] + for term in search_terms: + _, ms = timed_call( + self.announce_mgr.get_filtered_announces, + aspect="lxmf.delivery", + query=term, + limit=50, + offset=0, + ) + durations.append(ms) + + stats = latency_report("search", durations) + self.assertLess(stats["p95"], 150, "Announce search p95 > 150ms") + + def test_announce_search_with_blocked(self): + """Search with a block-list — simulates real NomadNet browser filtering.""" + print("\n[Announce] Search with blocked list:") + blocked = [secrets.token_hex(16) for _ in range(50)] + durations = [] + for _ in range(20): + _, ms = timed_call( + self.announce_mgr.get_filtered_announces, + aspect="lxmf.delivery", + query="abc", + blocked_identity_hashes=blocked, + limit=50, + offset=0, + ) + durations.append(ms) + + stats = latency_report("search+blocked", durations) + self.assertLess(stats["p95"], 200, "Announce search+blocked p95 > 200ms") + + def test_announce_count_latency(self): + """Count announces (used for pagination total).""" + print("\n[Announce] Count:") + durations = [] + for _ in range(30): + _, ms = timed_call( + self.announce_mgr.get_filtered_announces_count, + aspect="lxmf.delivery", + ) + durations.append(ms) + + stats = latency_report("count", durations) + 
self.assertLess(stats["p95"], 100, "Announce count p95 > 100ms") + + # =================================================================== + # FAVOURITES + # =================================================================== + + def test_favourites_load_latency(self): + """Load all favourites — typically displayed in sidebar.""" + print("\n[Favourites] Load all:") + durations = [] + for _ in range(50): + _, ms = timed_call(self.db.announces.get_favourites, "lxmf.delivery") + durations.append(ms) + + stats = latency_report("load_favs", durations) + self.assertLess(stats["p95"], 20, "Favourites load p95 > 20ms") + + def test_favourite_upsert_throughput(self): + """Measure upsert throughput for favourites.""" + print("\n[Favourites] Upsert throughput:") + durations = [] + for i in range(100): + dest = secrets.token_hex(16) + _, ms = timed_call( + self.db.announces.upsert_favourite, + dest, f"Bench Fav {i}", "lxmf.delivery", + ) + durations.append(ms) + + stats = latency_report("upsert_fav", durations) + self.assertGreater(stats["ops"], 500, "Favourite upsert < 500 ops/s") + + # =================================================================== + # CONVERSATIONS — load, search + # =================================================================== + + def test_conversations_load_latency(self): + """Load conversation list — the main messages sidebar query.""" + print("\n[Conversations] Load list (with JOINs):") + durations = [] + for _ in range(20): + _, ms = timed_call( + self.handler.get_conversations, + "local_hash", + limit=50, + offset=0, + ) + durations.append(ms) + + stats = latency_report("load_conversations", durations) + self.assertLess(stats["p95"], 200, "Conversation list p95 > 200ms") + + def test_conversations_search_latency(self): + """Search conversations — LIKE across titles, content, peer hashes.""" + print("\n[Conversations] Search:") + terms = ["Message title 5", "Content body", "abc", "zzz_nope", self.heavy_peer[:8]] + durations = [] + for term in terms: + _, ms = timed_call( + self.handler.get_conversations, + "local_hash", + search=term, + limit=50, + ) + durations.append(ms) + + stats = latency_report("search_conversations", durations) + self.assertLess(stats["p95"], 500, "Conversation search p95 > 500ms") + + def test_conversations_load_paginated(self): + """Paginate through conversation list at various offsets.""" + print("\n[Conversations] Paginated load:") + durations = [] + for offset in [0, 20, 50, 100, 150]: + _, ms = timed_call( + self.handler.get_conversations, + "local_hash", + limit=20, + offset=offset, + ) + durations.append(ms) + + stats = latency_report("paginated_conversations", durations) + self.assertLess(stats["p95"], 300, "Paginated conversations p95 > 300ms") + + # =================================================================== + # MESSAGES — load, search, upsert (drafts) + # =================================================================== + + def test_message_load_latency(self): + """Load messages for a single conversation (heavy peer with 2000 msgs).""" + print("\n[Messages] Load conversation messages:") + durations = [] + offsets = [0, 100, 500, 1000, 1900] + for offset in offsets: + result, ms = timed_call( + self.handler.get_conversation_messages, + "local_hash", + self.heavy_peer, + limit=50, + offset=offset, + ) + durations.append(ms) + self.assertEqual(len(result), 50) + + stats = latency_report("load_messages", durations) + self.assertLess(stats["p99"], 50, "Message load p99 > 50ms") + + def test_message_search_latency(self): + 
"""Search messages across all conversations — the global search.""" + print("\n[Messages] Global search:") + terms = [ + "Message title 100", + "Content body 5000", + secrets.token_hex(4), + "nonexistent_xyz_123", + ] + durations = [] + for term in terms: + _, ms = timed_call( + self.handler.search_messages, + "local_hash", + term, + ) + durations.append(ms) + + stats = latency_report("search_messages", durations) + self.assertLess(stats["p95"], 300, "Message search p95 > 300ms") + + def test_message_upsert_throughput(self): + """Measure message upsert throughput — simulates saving drafts rapidly.""" + print("\n[Messages] Upsert throughput (draft saves):") + durations = [] + peer = secrets.token_hex(16) + for i in range(200): + msg = make_message(peer, i + 100000) + _, ms = timed_call(self.db.messages.upsert_lxmf_message, msg) + durations.append(ms) + + stats = latency_report("upsert_message", durations) + self.assertGreater(stats["ops"], 300, "Message upsert < 300 ops/s") + self.assertLess(stats["p95"], 10, "Message upsert p95 > 10ms") + + def test_message_upsert_update_throughput(self): + """Measure message UPDATE throughput — re-saving existing messages (state changes).""" + print("\n[Messages] Update existing messages:") + peer = secrets.token_hex(16) + msgs = [] + for i in range(100): + msg = make_message(peer, i + 200000) + self.db.messages.upsert_lxmf_message(msg) + msgs.append(msg) + + durations = [] + for msg in msgs: + msg["state"] = "failed" + msg["content"] = "Updated content " + secrets.token_hex(16) + _, ms = timed_call(self.db.messages.upsert_lxmf_message, msg) + durations.append(ms) + + stats = latency_report("update_message", durations) + self.assertGreater(stats["ops"], 300, "Message update < 300 ops/s") + + # =================================================================== + # CONCURRENT WRITERS — contention stress + # =================================================================== + + def test_concurrent_message_writers(self): + """Multiple threads inserting messages simultaneously.""" + print("\n[Concurrency] Message writers:") + num_threads = 8 + msgs_per_thread = 100 + errors = [] + all_durations = [] + lock = threading.Lock() + + def writer(thread_id): + thread_durations = [] + peer = secrets.token_hex(16) + for i in range(msgs_per_thread): + msg = make_message(peer, thread_id * 10000 + i) + try: + _, ms = timed_call(self.db.messages.upsert_lxmf_message, msg) + thread_durations.append(ms) + except Exception as e: + errors.append(str(e)) + with lock: + all_durations.extend(thread_durations) + + threads = [threading.Thread(target=writer, args=(t,)) for t in range(num_threads)] + t0 = time.perf_counter() + for t in threads: + t.start() + for t in threads: + t.join() + wall_ms = (time.perf_counter() - t0) * 1000 + + total_ops = num_threads * msgs_per_thread + throughput = total_ops / (wall_ms / 1000) + print(f" Wall time: {wall_ms:.0f}ms for {total_ops} inserts ({throughput:.0f} ops/s)") + stats = latency_report("concurrent_write", all_durations) + + self.assertEqual(len(errors), 0, f"Writer errors: {errors[:5]}") + self.assertGreater(throughput, 100, "Concurrent write throughput < 100 ops/s") + + def test_concurrent_announce_writers(self): + """Multiple threads upserting announces simultaneously.""" + print("\n[Concurrency] Announce writers:") + num_threads = 6 + announces_per_thread = 100 + errors = [] + all_durations = [] + lock = threading.Lock() + + def writer(thread_id): + thread_durations = [] + for i in range(announces_per_thread): + data = 
make_announce(thread_id * 10000 + i) + try: + _, ms = timed_call(self.db.announces.upsert_announce, data) + thread_durations.append(ms) + except Exception as e: + errors.append(str(e)) + with lock: + all_durations.extend(thread_durations) + + threads = [threading.Thread(target=writer, args=(t,)) for t in range(num_threads)] + t0 = time.perf_counter() + for t in threads: + t.start() + for t in threads: + t.join() + wall_ms = (time.perf_counter() - t0) * 1000 + + total_ops = num_threads * announces_per_thread + throughput = total_ops / (wall_ms / 1000) + print(f" Wall time: {wall_ms:.0f}ms for {total_ops} upserts ({throughput:.0f} ops/s)") + stats = latency_report("concurrent_announce_write", all_durations) + + self.assertEqual(len(errors), 0, f"Writer errors: {errors[:5]}") + self.assertGreater(throughput, 100, "Concurrent announce write < 100 ops/s") + + def test_concurrent_read_write_contention(self): + """Writers inserting while readers query — simulates real app usage.""" + print("\n[Contention] Mixed read/write:") + num_writers = 4 + num_readers = 4 + ops_per_thread = 50 + write_errors = [] + read_errors = [] + write_durations = [] + read_durations = [] + lock = threading.Lock() + + def writer(thread_id): + local_durs = [] + peer = secrets.token_hex(16) + for i in range(ops_per_thread): + msg = make_message(peer, thread_id * 10000 + i) + try: + _, ms = timed_call(self.db.messages.upsert_lxmf_message, msg) + local_durs.append(ms) + except Exception as e: + write_errors.append(str(e)) + with lock: + write_durations.extend(local_durs) + + def reader(_thread_id): + local_durs = [] + for _ in range(ops_per_thread): + try: + _, ms = timed_call( + self.handler.get_conversations, + "local_hash", + limit=20, + ) + local_durs.append(ms) + except Exception as e: + read_errors.append(str(e)) + with lock: + read_durations.extend(local_durs) + + writers = [threading.Thread(target=writer, args=(t,)) for t in range(num_writers)] + readers = [threading.Thread(target=reader, args=(t,)) for t in range(num_readers)] + + t0 = time.perf_counter() + for t in writers + readers: + t.start() + for t in writers + readers: + t.join() + wall_ms = (time.perf_counter() - t0) * 1000 + + print(f" Wall time: {wall_ms:.0f}ms") + latency_report("contention_writes", write_durations) + latency_report("contention_reads", read_durations) + + self.assertEqual(len(write_errors), 0, f"Write errors: {write_errors[:5]}") + self.assertEqual(len(read_errors), 0, f"Read errors: {read_errors[:5]}") + + # =================================================================== + # LIKE SEARCH SCALING — how search degrades with data size + # =================================================================== + + def test_like_search_scaling(self): + """Measure how LIKE search scales across different table sizes. 
+        This catches missing FTS indexes or query plan regressions."""
+        print("\n[Scaling] LIKE search across data sizes:")
+
+        # Message search on the existing 10k dataset
+        _, ms_msg = timed_call(
+            self.handler.search_messages,
+            "local_hash",
+            "Content body",
+        )
+        print(f"  Message LIKE search ({self.NUM_MESSAGES} rows): {ms_msg:.2f}ms")
+        self.assertLess(ms_msg, 500, "Message LIKE search > 500ms on 10k rows")
+
+        # Announce search on the existing 5k dataset
+        _, ms_ann = timed_call(
+            self.announce_mgr.get_filtered_announces,
+            aspect="lxmf.delivery",
+            query="abc",
+            limit=50,
+        )
+        print(f"  Announce LIKE search ({self.NUM_ANNOUNCES} rows): {ms_ann:.2f}ms")
+        self.assertLess(ms_ann, 200, "Announce LIKE search > 200ms on 5k rows")
+
+        # Contacts search
+        _, ms_con = timed_call(self.db.contacts.get_contacts, search="Contact")
+        print(f"  Contacts LIKE search ({self.NUM_CONTACTS} rows): {ms_con:.2f}ms")
+        self.assertLess(ms_con, 50, "Contacts LIKE search > 50ms")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/backend/test_sql_injection_fixes.py b/tests/backend/test_sql_injection_fixes.py
new file mode 100644
index 0000000..1d4e8d4
--- /dev/null
+++ b/tests/backend/test_sql_injection_fixes.py
@@ -0,0 +1,392 @@
+"""Tests confirming SQL-injection fixes in the raw-SQL database layer.
+
+Covers:
+  - ATTACH DATABASE path escaping (single-quote doubling) in LegacyMigrator
+  - Column-name identifier filtering during legacy migration
+  - _validate_identifier / _ensure_column rejection in DatabaseSchema
+"""
+
+import os
+import re
+import sqlite3
+
+
+import pytest
+from hypothesis import HealthCheck, given, settings
+from hypothesis import strategies as st
+
+from meshchatx.src.backend.database.legacy_migrator import LegacyMigrator
+from meshchatx.src.backend.database.provider import DatabaseProvider
+from meshchatx.src.backend.database.schema import DatabaseSchema, _validate_identifier
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+@pytest.fixture
+def db_env(tmp_path):
+    """Provide an initialized DatabaseProvider + Schema in a temp directory."""
+    db_path = str(tmp_path / "current.db")
+    provider = DatabaseProvider(db_path)
+    schema = DatabaseSchema(provider)
+    schema.initialize()
+    yield provider, schema, tmp_path
+    provider.close()
+
+
+def _make_legacy_db(legacy_dir, identity_hash, tables_sql):
+    """Create a legacy database with the given CREATE TABLE + INSERT statements."""
+    identity_dir = os.path.join(legacy_dir, "identities", identity_hash)
+    os.makedirs(identity_dir, exist_ok=True)
+    db_path = os.path.join(identity_dir, "database.db")
+    conn = sqlite3.connect(db_path)
+    for sql in tables_sql:
+        conn.execute(sql)
+    conn.commit()
+    conn.close()
+    return db_path
+
+
+# ---------------------------------------------------------------------------
+# 1. 
ATTACH DATABASE — single-quote escaping +# --------------------------------------------------------------------------- + +class TestAttachDatabasePathEscaping: + + def test_path_without_quotes_migrates_normally(self, db_env): + provider, _schema, tmp_path = db_env + legacy_dir = str(tmp_path / "legacy_normal") + identity_hash = "aabbccdd" + _make_legacy_db( + legacy_dir, + identity_hash, + [ + "CREATE TABLE config (key TEXT UNIQUE, value TEXT)", + "INSERT INTO config (key, value) VALUES ('k1', 'v1')", + ], + ) + + migrator = LegacyMigrator(provider, legacy_dir, identity_hash) + assert migrator.migrate() is True + + row = provider.fetchone("SELECT value FROM config WHERE key = 'k1'") + assert row is not None + assert row["value"] == "v1" + + def test_path_with_single_quote_does_not_crash(self, db_env): + """A path containing a single quote must not cause SQL injection or crash.""" + provider, _schema, tmp_path = db_env + + quoted_dir = tmp_path / "it's_a_test" + quoted_dir.mkdir(parents=True, exist_ok=True) + legacy_dir = str(quoted_dir) + identity_hash = "aabbccdd" + _make_legacy_db( + legacy_dir, + identity_hash, + [ + "CREATE TABLE config (key TEXT UNIQUE, value TEXT)", + "INSERT INTO config (key, value) VALUES ('q1', 'quoted_val')", + ], + ) + + migrator = LegacyMigrator(provider, legacy_dir, identity_hash) + result = migrator.migrate() + assert result is True + + row = provider.fetchone("SELECT value FROM config WHERE key = 'q1'") + assert row is not None + assert row["value"] == "quoted_val" + + def test_path_with_multiple_quotes(self, db_env): + """Multiple single quotes in the path are all escaped.""" + provider, _schema, tmp_path = db_env + + weird_dir = tmp_path / "a'b'c" + weird_dir.mkdir(parents=True, exist_ok=True) + legacy_dir = str(weird_dir) + identity_hash = "11223344" + _make_legacy_db( + legacy_dir, + identity_hash, + [ + "CREATE TABLE config (key TEXT UNIQUE, value TEXT)", + "INSERT INTO config (key, value) VALUES ('mq', 'multi_quote')", + ], + ) + + migrator = LegacyMigrator(provider, legacy_dir, identity_hash) + assert migrator.migrate() is True + + row = provider.fetchone("SELECT value FROM config WHERE key = 'mq'") + assert row is not None + assert row["value"] == "multi_quote" + + def test_path_with_sql_injection_attempt(self, db_env): + """A path crafted to look like SQL injection is safely escaped.""" + provider, _schema, tmp_path = db_env + + evil_dir = tmp_path / "'; DROP TABLE config; --" + evil_dir.mkdir(parents=True, exist_ok=True) + legacy_dir = str(evil_dir) + identity_hash = "deadbeef" + _make_legacy_db( + legacy_dir, + identity_hash, + [ + "CREATE TABLE config (key TEXT UNIQUE, value TEXT)", + "INSERT INTO config (key, value) VALUES ('evil', 'nope')", + ], + ) + + migrator = LegacyMigrator(provider, legacy_dir, identity_hash) + migrator.migrate() + + tables = provider.fetchall( + "SELECT name FROM sqlite_master WHERE type='table' AND name='config'" + ) + assert len(tables) > 0, "config table must still exist after injection attempt" + + +# --------------------------------------------------------------------------- +# 2. 
Legacy migrator — malicious column names filtered out +# --------------------------------------------------------------------------- + +class TestLegacyColumnFiltering: + + def test_normal_columns_migrate(self, db_env): + """Standard column names pass through the identifier filter.""" + provider, _schema, tmp_path = db_env + legacy_dir = str(tmp_path / "legacy_cols") + identity_hash = "aabb0011" + _make_legacy_db( + legacy_dir, + identity_hash, + [ + "CREATE TABLE config (key TEXT UNIQUE, value TEXT)", + "INSERT INTO config (key, value) VALUES ('c1', 'ok')", + ], + ) + + migrator = LegacyMigrator(provider, legacy_dir, identity_hash) + assert migrator.migrate() is True + + row = provider.fetchone("SELECT value FROM config WHERE key = 'c1'") + assert row is not None + + def test_malicious_column_name_is_skipped(self, db_env): + """A column with SQL metacharacters in its name must be silently skipped.""" + provider, _schema, tmp_path = db_env + legacy_dir = str(tmp_path / "legacy_evil_col") + identity_hash = "cc00dd00" + + identity_dir = os.path.join(legacy_dir, "identities", identity_hash) + os.makedirs(identity_dir, exist_ok=True) + db_path = os.path.join(identity_dir, "database.db") + conn = sqlite3.connect(db_path) + conn.execute( + 'CREATE TABLE config (key TEXT UNIQUE, value TEXT, "key; DROP TABLE config" TEXT)' + ) + conn.execute( + "INSERT INTO config (key, value) VALUES ('safe', 'data')" + ) + conn.commit() + conn.close() + + migrator = LegacyMigrator(provider, legacy_dir, identity_hash) + assert migrator.migrate() is True + + row = provider.fetchone("SELECT value FROM config WHERE key = 'safe'") + assert row is not None + assert row["value"] == "data" + + def test_column_with_parentheses_is_skipped(self, db_env): + """Columns with () in the name are rejected by the identifier regex.""" + provider, _schema, tmp_path = db_env + legacy_dir = str(tmp_path / "legacy_parens_col") + identity_hash = "ee00ff00" + + identity_dir = os.path.join(legacy_dir, "identities", identity_hash) + os.makedirs(identity_dir, exist_ok=True) + db_path = os.path.join(identity_dir, "database.db") + conn = sqlite3.connect(db_path) + conn.execute( + 'CREATE TABLE config (key TEXT UNIQUE, value TEXT, "evil()" TEXT)' + ) + conn.execute( + "INSERT INTO config (key, value) VALUES ('p1', 'parens')" + ) + conn.commit() + conn.close() + + migrator = LegacyMigrator(provider, legacy_dir, identity_hash) + assert migrator.migrate() is True + + row = provider.fetchone("SELECT value FROM config WHERE key = 'p1'") + assert row is not None + + +# --------------------------------------------------------------------------- +# 3. _validate_identifier — unit tests +# --------------------------------------------------------------------------- + +class TestValidateIdentifier: + + @pytest.mark.parametrize( + "name", + [ + "config", + "lxmf_messages", + "A", + "_private", + "Column123", + "a_b_c_d", + ], + ) + def test_valid_identifiers_pass(self, name): + assert _validate_identifier(name) == name + + @pytest.mark.parametrize( + "name", + [ + "", + "123abc", + "table name", + "col;drop", + "a'b", + 'a"b', + "col()", + "x--y", + "a,b", + "hello\nworld", + "tab\there", + "col/**/name", + ], + ) + def test_invalid_identifiers_raise(self, name): + with pytest.raises(ValueError, match="Invalid SQL"): + _validate_identifier(name) + + +# --------------------------------------------------------------------------- +# 4. 
_ensure_column — rejects injection via table/column names +# --------------------------------------------------------------------------- + +class TestEnsureColumnInjection: + + def test_ensure_column_rejects_malicious_table_name(self, db_env): + _provider, schema, _tmp_path = db_env + with pytest.raises(ValueError, match="Invalid SQL table name"): + schema._ensure_column("config; DROP TABLE config", "new_col", "TEXT") + + def test_ensure_column_rejects_malicious_column_name(self, db_env): + _provider, schema, _tmp_path = db_env + with pytest.raises(ValueError, match="Invalid SQL column name"): + schema._ensure_column("config", "col; DROP TABLE config", "TEXT") + + def test_ensure_column_works_for_valid_names(self, db_env): + _provider, schema, _tmp_path = db_env + result = schema._ensure_column("config", "test_new_col", "TEXT") + assert result is True + + def test_ensure_column_idempotent(self, db_env): + _provider, schema, _tmp_path = db_env + schema._ensure_column("config", "idempotent_col", "TEXT") + result = schema._ensure_column("config", "idempotent_col", "TEXT") + assert result is True + + +# --------------------------------------------------------------------------- +# 5. Property-based tests — identifier regex +# --------------------------------------------------------------------------- + +_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +@given(name=st.text(min_size=1, max_size=80)) +@settings(deadline=None) +def test_validate_identifier_never_allows_sql_metacharacters(name): + """No string accepted by _validate_identifier contains SQL metacharacters.""" + try: + _validate_identifier(name) + except ValueError: + return + + assert ";" not in name + assert "'" not in name + assert '"' not in name + assert "(" not in name + assert ")" not in name + assert " " not in name + assert "-" not in name + assert "/" not in name + assert "\\" not in name + assert "\n" not in name + assert "\r" not in name + assert "\t" not in name + assert "," not in name + assert _IDENTIFIER_RE.match(name) + + +@given(name=st.from_regex(r"[A-Za-z_][A-Za-z0-9_]{0,30}", fullmatch=True)) +@settings(deadline=None) +def test_validate_identifier_accepts_all_valid_identifiers(name): + """Every string matching the identifier pattern is accepted.""" + assert _validate_identifier(name) == name + + +@given( + name=st.text( + alphabet=st.sampled_from(list(";'\"()- \t\n\r,/*")), + min_size=1, + max_size=30, + ) +) +@settings(deadline=None) +def test_validate_identifier_rejects_pure_metacharacter_strings(name): + """Strings composed entirely of SQL metacharacters are always rejected.""" + with pytest.raises(ValueError): + _validate_identifier(name) + + +# --------------------------------------------------------------------------- +# 6. 
ATTACH path escaping — property-based +# --------------------------------------------------------------------------- + +@given( + path_segment=st.text( + alphabet=st.characters( + whitelist_categories=("L", "N", "P", "S", "Z"), + ), + min_size=1, + max_size=60, + ) +) +@settings(deadline=None, suppress_health_check=[HealthCheck.too_slow]) +def test_attach_path_escaping_never_breaks_sql(path_segment): + """The quote-doubling escaping produces a string that SQLite can parse + without breaking out of the literal, regardless of the path content.""" + safe = path_segment.replace("'", "''") + sql = f"ATTACH DATABASE '{safe}' AS test_alias" + + assert sql.count("ATTACH DATABASE '") == 1 + + after_open = sql.split("ATTACH DATABASE '", 1)[1] + in_literal = True + i = 0 + while i < len(after_open): + if after_open[i] == "'": + if i + 1 < len(after_open) and after_open[i + 1] == "'": + i += 2 + continue + else: + in_literal = False + remainder = after_open[i + 1 :] + break + i += 1 + + if not in_literal: + assert remainder.strip() == "AS test_alias", ( + f"Unexpected SQL after literal end: {remainder!r}" + )
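
The two defenses this patch combines can be read in isolation. The sketch
below restates them outside the codebase; validate_identifier and
attach_database are illustrative names, not functions from this repository:

    import re
    import sqlite3

    _IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")

    def validate_identifier(name: str) -> str:
        # Allow-list validation: a bare SQL identifier may only contain
        # ASCII letters, digits, and underscores, and must not start with a
        # digit. Quotes, semicolons, whitespace, and comment markers all
        # fail the match, so they can never reach an f-string query.
        if not _IDENTIFIER_RE.match(name):
            raise ValueError(f"Invalid SQL identifier: {name!r}")
        return name

    def attach_database(conn: sqlite3.Connection, path: str, alias: str) -> None:
        # The alias is an identifier, and identifiers can never be supplied
        # as bound parameters, so it must pass the allow-list above. The
        # path is embedded as a string literal with single quotes doubled,
        # which is SQLite's own escaping rule for quotes inside a literal.
        safe_path = path.replace("'", "''")
        conn.execute(f"ATTACH DATABASE '{safe_path}' AS {validate_identifier(alias)}")

    # e.g. attach_database(conn, "it's a test/legacy.db", "legacy_1a2b3c4d")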