Add SQL injection prevention measures in database migration and schema validation

- Implemented identifier validation regex in legacy migrator and schema.
- Enhanced database path handling to escape single quotes during ATTACH DATABASE.
- Added tests for SQL injection scenarios, ensuring robustness against malicious inputs.
- Introduced fuzz tests for DAO layers to cover edge cases and improve overall test coverage.
This commit is contained in:
Sudo-Ivan
2026-03-06 03:08:28 -06:00
parent 905a4592be
commit 9d7ae2017b
5 changed files with 1731 additions and 14 deletions
@@ -1,4 +1,7 @@
import os
import re
_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
class LegacyMigrator:
@@ -73,7 +76,8 @@ class LegacyMigrator:
# Attach the legacy database
# We use a randomized alias to avoid collisions
alias = f"legacy_{os.urandom(4).hex()}"
self.provider.execute(f"ATTACH DATABASE '{legacy_path}' AS {alias}")
safe_path = legacy_path.replace("'", "''")
self.provider.execute(f"ATTACH DATABASE '{safe_path}' AS {alias}") # noqa: S608
# Tables that existed in the legacy Peewee version
tables_to_migrate = [
@@ -121,7 +125,9 @@ class LegacyMigrator:
common_columns = [
col
for col in legacy_columns
if col in current_columns and col.lower() != "id"
if col in current_columns
and col.lower() != "id"
and _IDENTIFIER_RE.match(col)
]
if common_columns:
+16 -12
View File
@@ -1,5 +1,16 @@
import re
from .provider import DatabaseProvider
_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
def _validate_identifier(name: str, label: str = "identifier") -> str:
    """Return *name* unchanged if it is a safe bare SQL identifier.

    Raises ValueError when the name does not match the identifier
    pattern, preventing it from being interpolated into SQL text.
    """
    if _IDENTIFIER_RE.match(name) is None:
        msg = f"Invalid SQL {label}: {name!r}"
        raise ValueError(msg)
    return name
class DatabaseSchema:
LATEST_VERSION = 38
@@ -28,21 +39,18 @@ class DatabaseSchema:
def _ensure_column(self, table_name, column_name, column_type):
"""Add a column to a table if it doesn't exist."""
# First check if it exists using PRAGMA
_validate_identifier(table_name, "table name")
_validate_identifier(column_name, "column name")
cursor = self.provider.connection.cursor()
try:
cursor.execute(f"PRAGMA table_info({table_name})")
cursor.execute(f"PRAGMA table_info({table_name})") # noqa: S608
columns = [row[1] for row in cursor.fetchall()]
finally:
cursor.close()
if column_name not in columns:
try:
# SQLite has limitations on ALTER TABLE ADD COLUMN:
# 1. Cannot add UNIQUE or PRIMARY KEY columns
# 2. Cannot add columns with non-constant defaults (like CURRENT_TIMESTAMP)
# Strip non-constant defaults if present for the ALTER TABLE statement
stmt_type = column_type
forbidden_defaults = [
"CURRENT_TIMESTAMP",
@@ -51,9 +59,6 @@ class DatabaseSchema:
]
for forbidden in forbidden_defaults:
if f"DEFAULT {forbidden}" in stmt_type.upper():
# Remove the DEFAULT part for the ALTER statement
import re
stmt_type = re.sub(
f"DEFAULT\\s+{forbidden}",
"",
@@ -61,9 +66,8 @@ class DatabaseSchema:
flags=re.IGNORECASE,
).strip()
# Use the connection directly to avoid any middle-ware issues
res = self._safe_execute(
f"ALTER TABLE {table_name} ADD COLUMN {column_name} {stmt_type}",
f"ALTER TABLE {table_name} ADD COLUMN {column_name} {stmt_type}", # noqa: S608
)
return res is not None
except Exception as e:
+717
View File
@@ -0,0 +1,717 @@
"""Property-based and fuzz tests for DAO layers, MessageHandler, and utility functions.
Covers gaps identified in the existing test suite:
- ContactsDAO: add, search, update with adversarial strings
- ConfigDAO: get/set/delete round-trip with arbitrary keys and values
- MiscDAO: spam keywords, notifications, keyboard shortcuts
- TelephoneDAO: add + search with adversarial strings
- VoicemailDAO: add + search
- DebugLogsDAO: insert + search + count consistency
- RingtoneDAO: add + update with adversarial display names
- MapDrawingsDAO: upsert + update with arbitrary JSON data
- MessageDAO: folder create/rename
- MessageHandler: search_messages, get_conversations with adversarial search terms
- _safe_href: URL sanitization (XSS vectors)
- parse_bool_query_param: arbitrary strings
- message_fields_have_attachments: arbitrary JSON
"""
import json
import pytest
from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st
from meshchatx.src.backend.database import Database
from meshchatx.src.backend.markdown_renderer import _safe_href
from meshchatx.src.backend.meshchat_utils import (
message_fields_have_attachments,
parse_bool_query_param,
)
from meshchatx.src.backend.message_handler import MessageHandler
# ---------------------------------------------------------------------------
# Strategies
# ---------------------------------------------------------------------------
# Strings that are valid for most text columns but include adversarial chars
# (letters, numbers, punctuation, symbols, separators and control chars —
# category "C" deliberately allows NULs and other control characters).
st_nasty_text = st.text(
    alphabet=st.characters(whitelist_categories=("L", "N", "P", "S", "Z", "C")),
    min_size=0,
    max_size=300,
)

# Hand-picked injection payloads: quote breaking, comment truncation,
# UNION extraction, LIKE wildcards, NUL bytes, and stacked statements.
st_sql_payloads = st.sampled_from([
    "'; DROP TABLE config; --",
    "' OR '1'='1",
    "\" OR \"1\"=\"1",
    "1; SELECT * FROM sqlite_master",
    "Robert'); DROP TABLE lxmf_messages;--",
    "%",
    "%%",
    "_",
    "\x00",
    "' UNION SELECT key, value FROM config --",
    "NULL",
    "1=1",
    "admin'--",
    "' AND 1=CONVERT(int,(SELECT TOP 1 table_name FROM information_schema.tables))--",
])

# Either random adversarial text or one of the known SQL payloads.
st_search_term = st.one_of(st_nasty_text, st_sql_payloads)

# Lowercase hex strings of 16-64 chars, shaped like identity hashes.
st_hex_hash = st.from_regex(r"[0-9a-f]{16,64}", fullmatch=True)
# ---------------------------------------------------------------------------
# Fixture: initialised in-memory database
# ---------------------------------------------------------------------------
@pytest.fixture
def db():
    """Yield a fresh, fully-initialised in-memory database per test."""
    database = Database(":memory:")
    database.initialize()
    yield database
    # Teardown: close the connection after the test body finishes.
    database.close()
@pytest.fixture
def handler(db):
    """MessageHandler bound to the per-test in-memory database."""
    return MessageHandler(db)
# ===================================================================
# ContactsDAO
# ===================================================================
class TestContactsDAOFuzzing:
    """Fuzz ContactsDAO add/search/update with adversarial text and hashes."""

    @given(
        name=st_nasty_text,
        identity_hash=st_hex_hash,
        lxmf_addr=st.one_of(st.none(), st_hex_hash),
        lxst_addr=st.one_of(st.none(), st_hex_hash),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=60,
    )
    def test_add_contact_roundtrip(self, db, name, identity_hash, lxmf_addr, lxst_addr):
        """A contact stored with arbitrary text must round-trip verbatim."""
        db.contacts.add_contact(
            name=name,
            remote_identity_hash=identity_hash,
            lxmf_address=lxmf_addr,
            lxst_address=lxst_addr,
        )
        row = db.contacts.get_contact_by_identity_hash(identity_hash)
        assert row is not None
        assert row["name"] == name

    @given(search=st_search_term)
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=80,
    )
    def test_search_contacts_never_crashes(self, db, search):
        """Search with SQL payloads must not raise; count stays non-negative."""
        results = db.contacts.get_contacts(search=search)
        assert isinstance(results, list)
        count = db.contacts.get_contacts_count(search=search)
        assert isinstance(count, int)
        assert count >= 0

    @given(
        name=st_nasty_text,
        identity_hash=st_hex_hash,
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_update_contact_never_crashes(self, db, name, identity_hash):
        """Renaming a contact to arbitrary text must persist verbatim."""
        db.contacts.add_contact(name="orig", remote_identity_hash=identity_hash)
        row = db.contacts.get_contact_by_identity_hash(identity_hash)
        assert row is not None
        db.contacts.update_contact(row["id"], name=name)
        updated = db.contacts.get_contact(row["id"])
        assert updated["name"] == name
# ===================================================================
# ConfigDAO
# ===================================================================
class TestConfigDAOFuzzing:
    """Fuzz ConfigDAO get/set/delete with arbitrary keys and values."""

    @given(key=st_nasty_text.filter(lambda x: len(x) > 0), value=st_nasty_text)
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=80,
    )
    def test_config_set_get_roundtrip(self, db, key, value):
        """Any non-empty key with arbitrary value must round-trip exactly."""
        db.config.set(key, value)
        got = db.config.get(key)
        assert got == value

    @given(key=st_nasty_text.filter(lambda x: len(x) > 0))
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_config_delete_never_crashes(self, db, key):
        """Delete must remove the key regardless of its content."""
        db.config.set(key, "temp")
        db.config.delete(key)
        assert db.config.get(key) is None

    @given(key=st_sql_payloads)
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
    )
    def test_config_sql_injection_keys(self, db, key):
        """Injection payloads used as keys must be stored literally, and the
        config table itself must survive (no DROP executed)."""
        db.config.set(key, "injected?")
        assert db.config.get(key) == "injected?"
        tables = db.provider.fetchall(
            "SELECT name FROM sqlite_master WHERE type='table'"
        )
        table_names = {r["name"] for r in tables}
        assert "config" in table_names
# ===================================================================
# MiscDAO — spam keywords, notifications, keyboard shortcuts
# ===================================================================
class TestMiscDAOFuzzing:
    """Fuzz MiscDAO: spam keywords, notifications, shortcuts, archived pages."""

    @given(keyword=st_nasty_text.filter(lambda x: len(x) > 0))
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=50,
    )
    def test_spam_keyword_roundtrip(self, db, keyword):
        """A spam keyword with arbitrary text must be stored and listed."""
        db.misc.add_spam_keyword(keyword)
        keywords = db.misc.get_spam_keywords()
        assert any(k["keyword"] == keyword for k in keywords)

    @given(title=st_nasty_text, content=st_nasty_text)
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=50,
    )
    def test_check_spam_never_crashes(self, db, title, content):
        """Spam check must return a bool for any title/content pair."""
        result = db.misc.check_spam_keywords(title, content)
        assert isinstance(result, bool)

    @given(
        ntype=st.sampled_from(["message", "call", "voicemail", "system"]),
        remote_hash=st_hex_hash,
        title=st_nasty_text,
        content=st_nasty_text,
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_add_notification_never_crashes(self, db, ntype, remote_hash, title, content):
        """Notifications with adversarial text must insert and list cleanly."""
        db.misc.add_notification(ntype, remote_hash, title, content)
        notifications = db.misc.get_notifications()
        assert isinstance(notifications, list)

    @given(
        action=st_nasty_text.filter(lambda x: len(x) > 0),
        keys=st_nasty_text,
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_keyboard_shortcut_roundtrip(self, db, action, keys):
        """An upserted shortcut must come back with identical action/keys."""
        identity = "testhash"
        db.misc.upsert_keyboard_shortcut(identity, action, keys)
        shortcuts = db.misc.get_keyboard_shortcuts(identity)
        assert any(s["action"] == action and s["keys"] == keys for s in shortcuts)

    @given(
        query=st_search_term,
        dest=st.one_of(st.none(), st_hex_hash),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_archived_pages_search_never_crashes(self, db, query, dest):
        """Archived-page search must tolerate payload queries and None dest."""
        results = db.misc.get_archived_pages_paginated(
            destination_hash=dest,
            query=query if query else None,
        )
        assert isinstance(results, list)
# ===================================================================
# TelephoneDAO
# ===================================================================
class TestTelephoneDAOFuzzing:
    """Fuzz TelephoneDAO call-history inserts and search paths."""

    @given(
        name=st_nasty_text,
        identity_hash=st_hex_hash,
        status=st.sampled_from(["answered", "missed", "rejected", "busy"]),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_add_call_history_never_crashes(self, db, name, identity_hash, status):
        """Inserting a call record with an adversarial name must succeed."""
        import time

        db.telephone.add_call_history(
            remote_identity_hash=identity_hash,
            remote_identity_name=name,
            is_incoming=True,
            status=status,
            duration_seconds=42,
            timestamp=time.time(),
        )
        history = db.telephone.get_call_history()
        assert isinstance(history, list)
        assert len(history) > 0

    @given(search=st_search_term)
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=60,
    )
    def test_call_history_search_never_crashes(self, db, search):
        """History search must not raise on SQL payloads."""
        results = db.telephone.get_call_history(search=search)
        assert isinstance(results, list)

    @given(search=st_search_term)
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=60,
    )
    def test_call_recordings_search_never_crashes(self, db, search):
        """Recording search must not raise on SQL payloads."""
        results = db.telephone.get_call_recordings(search=search)
        assert isinstance(results, list)
# ===================================================================
# VoicemailDAO
# ===================================================================
class TestVoicemailDAOFuzzing:
    """Fuzz VoicemailDAO inserts and search."""

    @given(
        name=st_nasty_text,
        identity_hash=st_hex_hash,
        filename=st_nasty_text,
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_add_voicemail_never_crashes(self, db, name, identity_hash, filename):
        """A voicemail with adversarial name/filename must insert cleanly."""
        import time

        db.voicemails.add_voicemail(
            remote_identity_hash=identity_hash,
            remote_identity_name=name,
            filename=filename,
            duration_seconds=10,
            timestamp=time.time(),
        )
        vms = db.voicemails.get_voicemails()
        assert isinstance(vms, list)

    @given(search=st_search_term)
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=60,
    )
    def test_voicemail_search_never_crashes(self, db, search):
        """Voicemail search must not raise on SQL payloads."""
        results = db.voicemails.get_voicemails(search=search)
        assert isinstance(results, list)
# ===================================================================
# DebugLogsDAO
# ===================================================================
class TestDebugLogsDAOFuzzing:
    """Fuzz DebugLogsDAO insert/search and search-vs-count consistency."""

    @given(
        level=st.sampled_from(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
        module=st_nasty_text,
        message=st_nasty_text,
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=50,
    )
    def test_insert_log_never_crashes(self, db, level, module, message):
        """Inserting a log with arbitrary module/message must raise the count."""
        db.debug_logs.insert_log(level, module, message)
        total = db.debug_logs.get_total_count()
        assert total > 0

    @given(
        search=st_search_term,
        level=st.one_of(st.none(), st.sampled_from(["DEBUG", "INFO", "ERROR"])),
        module=st.one_of(st.none(), st_nasty_text),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=60,
    )
    def test_search_logs_never_crashes(self, db, search, level, module):
        """Filtered search and count must not raise on SQL payloads."""
        results = db.debug_logs.get_logs(search=search, level=level, module=module)
        assert isinstance(results, list)
        count = db.debug_logs.get_total_count(search=search, level=level, module=module)
        assert isinstance(count, int)
        assert count >= 0

    @given(
        search=st_search_term,
        level=st.one_of(st.none(), st.sampled_from(["DEBUG", "INFO", "ERROR"])),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_count_matches_result_length(self, db, search, level):
        """get_total_count must be consistent with len(get_logs) when no limit truncation."""
        # limit=100000 is far above any row count these tests can create.
        results = db.debug_logs.get_logs(
            search=search, level=level, limit=100000, offset=0,
        )
        count = db.debug_logs.get_total_count(search=search, level=level)
        assert count == len(results)
# ===================================================================
# RingtoneDAO
# ===================================================================
class TestRingtoneDAOFuzzing:
    """Fuzz RingtoneDAO add/update with adversarial display names."""

    @given(
        filename=st_nasty_text.filter(lambda x: len(x) > 0),
        display_name=st_nasty_text.filter(lambda x: len(x) > 0),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_add_ringtone_roundtrip(self, db, filename, display_name):
        """An added ringtone must be retrievable with the same display name."""
        rid = db.ringtones.add(filename, f"store_{filename[:20]}", display_name)
        assert rid is not None
        row = db.ringtones.get_by_id(rid)
        assert row is not None
        assert row["display_name"] == display_name

    @given(new_name=st_nasty_text.filter(lambda x: len(x) > 0))
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=30,
    )
    def test_update_display_name(self, db, new_name):
        """Updating the display name to arbitrary text must persist."""
        rid = db.ringtones.add("test.ogg", "store_test.ogg", "Original")
        db.ringtones.update(rid, display_name=new_name)
        row = db.ringtones.get_by_id(rid)
        assert row["display_name"] == new_name
# ===================================================================
# MapDrawingsDAO
# ===================================================================
class TestMapDrawingsDAOFuzzing:
    """Fuzz MapDrawingsDAO upsert with arbitrary text or JSON payloads."""

    @given(
        name=st_nasty_text.filter(lambda x: len(x) > 0),
        data=st.one_of(
            st_nasty_text,
            st.builds(json.dumps, st.dictionaries(st.text(max_size=10), st.text(max_size=50), max_size=10)),
        ),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_upsert_drawing_roundtrip(self, db, name, data):
        """An upserted drawing must be listed under the same name."""
        identity = "deadbeef01234567"
        db.map_drawings.upsert_drawing(identity, name, data)
        drawings = db.map_drawings.get_drawings(identity)
        assert any(d["name"] == name for d in drawings)
# ===================================================================
# MessageDAO — folders
# ===================================================================
class TestMessageDAOFoldersFuzzing:
    """Fuzz MessageDAO folder create/rename with adversarial names."""

    @given(name=st_nasty_text.filter(lambda x: len(x) > 0))
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_create_folder_roundtrip(self, db, name):
        """A created folder must appear in the folder list verbatim."""
        db.messages.create_folder(name)
        folders = db.messages.get_all_folders()
        assert any(f["name"] == name for f in folders)

    @given(
        original=st_nasty_text.filter(lambda x: len(x) > 0),
        renamed=st_nasty_text.filter(lambda x: len(x) > 0),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=30,
    )
    def test_rename_folder(self, db, original, renamed):
        """Renaming a folder must change exactly that folder's name."""
        # Start from a clean table so the id lookup below is unambiguous.
        db.provider.execute("DELETE FROM lxmf_folders")
        cursor = db.messages.create_folder(original)
        folder_id = cursor.lastrowid
        db.messages.rename_folder(folder_id, renamed)
        folders = db.messages.get_all_folders()
        found = [f for f in folders if f["id"] == folder_id]
        assert len(found) == 1
        assert found[0]["name"] == renamed
# ===================================================================
# MessageHandler — search
# ===================================================================
class TestMessageHandlerFuzzing:
    """Fuzz MessageHandler search and conversation queries."""

    @given(search=st_search_term)
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=80,
    )
    def test_search_messages_never_crashes(self, handler, search):
        """Message search must not raise on SQL payloads."""
        results = handler.search_messages("local_hash", search)
        assert isinstance(results, list)

    @given(search=st_search_term)
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=60,
    )
    def test_get_conversations_search_never_crashes(self, handler, search):
        """Conversation search must not raise on SQL payloads."""
        results = handler.get_conversations("local_hash", search=search)
        assert isinstance(results, list)

    @given(
        search=st_search_term,
        folder_id=st.one_of(st.none(), st.integers(min_value=0, max_value=1000)),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_get_conversations_with_filters(self, handler, search, folder_id):
        """Combining search, folder filter and unread filter must not raise."""
        results = handler.get_conversations(
            "local_hash",
            search=search,
            folder_id=folder_id,
            filter_unread=True,
        )
        assert isinstance(results, list)

    @given(
        dest=st_hex_hash,
        after_id=st.one_of(st.none(), st.integers(min_value=-1000, max_value=1000)),
        before_id=st.one_of(st.none(), st.integers(min_value=-1000, max_value=1000)),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_get_conversation_messages_never_crashes(self, handler, dest, after_id, before_id):
        """Cursor bounds (including negative ids) must not raise."""
        results = handler.get_conversation_messages(
            "local_hash", dest, after_id=after_id, before_id=before_id,
        )
        assert isinstance(results, list)
# ===================================================================
# _safe_href — URL sanitization
# ===================================================================
class TestSafeHrefFuzzing:
    """Fuzz _safe_href URL sanitization against XSS protocol vectors."""

    @given(url=st.text(max_size=500))
    @settings(deadline=None, max_examples=200)
    def test_safe_href_never_returns_dangerous_protocol(self, url):
        """No input may ever yield a javascript/data/vbscript/file URL."""
        result = _safe_href(url)
        lower = result.strip().lower()
        assert not lower.startswith("javascript:")
        assert not lower.startswith("data:")
        assert not lower.startswith("vbscript:")
        assert not lower.startswith("file:")

    @pytest.mark.parametrize(
        "url",
        [
            "javascript:alert(1)",
            "JAVASCRIPT:alert(1)",
            " javascript:alert(1) ",
            "jAvAsCrIpT:alert(document.cookie)",
            "data:text/html,<script>alert(1)</script>",
            "DATA:text/html;base64,PHNjcmlwdD4=",
            "vbscript:MsgBox('xss')",
            "file:///etc/passwd",
        ],
    )
    def test_safe_href_blocks_xss_vectors(self, url):
        """Known XSS vectors (case/whitespace variants) collapse to '#'."""
        assert _safe_href(url) == "#"

    @pytest.mark.parametrize(
        "url",
        [
            "https://example.com",
            "http://example.com",
            "/relative/path",
            "#anchor",
            "mailto:test@example.com",
        ],
    )
    def test_safe_href_allows_safe_urls(self, url):
        """Benign schemes and relative links pass through unchanged."""
        assert _safe_href(url) == url

    @given(
        url=st.from_regex(
            r"(javascript|data|vbscript|file):[A-Za-z0-9()=;,/+]+",
            fullmatch=True,
        )
    )
    @settings(deadline=None, max_examples=100)
    def test_safe_href_blocks_all_generated_dangerous_urls(self, url):
        """Every generated dangerous-scheme URL collapses to '#'."""
        assert _safe_href(url) == "#"

    @given(scheme=st.text(min_size=1, max_size=20).filter(lambda s: ":" not in s and "/" not in s))
    @settings(deadline=None, max_examples=60)
    def test_safe_href_blocks_unknown_schemes(self, scheme):
        """Any unknown scheme like 'foo:bar' should be blocked."""
        url = f"{scheme}:something"
        result = _safe_href(url)
        lower = url.strip().lower()
        # Mirror the allow-list: anything else must be neutralised to "#".
        if any(lower.startswith(p) for p in ("https://", "http://", "/", "#", "mailto:")):
            assert result == url
        else:
            assert result == "#"
# ===================================================================
# parse_bool_query_param
# ===================================================================
class TestParseBoolQueryParam:
    """Fuzz and pin the truthy-string parser for query parameters."""

    @given(value=st.text(max_size=100))
    @settings(deadline=None, max_examples=200)
    def test_never_crashes_and_returns_bool(self, value):
        """Any string must produce a plain bool, never an exception."""
        result = parse_bool_query_param(value)
        assert isinstance(result, bool)

    @pytest.mark.parametrize("value,expected", [
        ("1", True),
        ("true", True),
        ("True", True),
        ("TRUE", True),
        ("yes", True),
        ("YES", True),
        ("on", True),
        ("ON", True),
        ("0", False),
        ("false", False),
        ("no", False),
        ("off", False),
        ("", False),
        ("maybe", False),
        (None, False),
    ])
    def test_known_values(self, value, expected):
        """Pin the accepted truthy spellings; everything else is False."""
        assert parse_bool_query_param(value) is expected
# ===================================================================
# message_fields_have_attachments
# ===================================================================
class TestMessageFieldsHaveAttachments:
    """Fuzz the attachment detector with arbitrary and malformed JSON."""

    @given(data=st.text(max_size=500))
    @settings(deadline=None, max_examples=100)
    def test_never_crashes_on_arbitrary_text(self, data):
        """Non-JSON input must yield a bool, not an exception."""
        result = message_fields_have_attachments(data)
        assert isinstance(result, bool)

    @given(
        obj=st.dictionaries(
            st.text(max_size=20),
            st.one_of(st.text(max_size=50), st.integers(), st.booleans(), st.none()),
            max_size=10,
        )
    )
    @settings(deadline=None, max_examples=80)
    def test_never_crashes_on_arbitrary_json_objects(self, obj):
        """Any flat JSON object must yield a bool."""
        result = message_fields_have_attachments(json.dumps(obj))
        assert isinstance(result, bool)

    @pytest.mark.parametrize("fields_json,expected", [
        (None, False),
        ("", False),
        ("{}", False),
        ("not json", False),
        ('{"image": "base64data"}', True),
        ('{"audio": "base64data"}', True),
        ('{"file_attachments": [{"name": "f.txt"}]}', True),
        ('{"file_attachments": []}', False),
        ('{"other_field": "value"}', False),
    ])
    def test_known_cases(self, fields_json, expected):
        """Pin the recognised attachment keys and the empty-list case."""
        assert message_fields_have_attachments(fields_json) is expected

    @given(
        data=st.recursive(
            st.one_of(st.text(max_size=20), st.integers(), st.booleans(), st.none()),
            lambda children: st.lists(children, max_size=5)
            | st.dictionaries(st.text(max_size=10), children, max_size=5),
            max_leaves=30,
        )
    )
    @settings(deadline=None, max_examples=60)
    def test_deeply_nested_json_never_crashes(self, data):
        """Deep nesting must not trigger recursion or type errors."""
        result = message_fields_have_attachments(json.dumps(data))
        assert isinstance(result, bool)
+598
View File
@@ -0,0 +1,598 @@
"""Performance regression tests for the critical hot paths.
Focus areas (user priority):
- NomadNet browser: load announces, search announces, favourites
- Messages: load conversations, search messages, load conversation messages,
upsert messages (drafts)
Metrics collected:
- ops/sec throughput
- p50 / p95 / p99 latency
- Concurrent writer contention
- LIKE-search scaling
All tests have hard assertions so regressions fail CI.
"""
import os
import secrets
import shutil
import statistics
import tempfile
import threading
import time
import unittest
from unittest.mock import MagicMock
from meshchatx.src.backend.announce_manager import AnnounceManager
from meshchatx.src.backend.database import Database
from meshchatx.src.backend.message_handler import MessageHandler
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def percentile(data, pct):
    """Linear-interpolation percentile of *data* (``pct`` in 0..100).

    Returns 0 for an empty sequence so empty benchmark runs still report.
    """
    if not data:
        return 0
    ordered = sorted(data)
    rank = (len(ordered) - 1) * (pct / 100)
    lo = int(rank)
    hi = lo + 1
    if hi >= len(ordered):
        return ordered[lo]
    # Interpolate between the two neighbouring samples.
    return ordered[lo] + (rank - lo) * (ordered[hi] - ordered[lo])
def timed_call(fn, *args, **kwargs):
    """Invoke *fn* and return ``(result, elapsed wall time in ms)``."""
    start = time.perf_counter()
    out = fn(*args, **kwargs)
    elapsed_ms = (time.perf_counter() - start) * 1000
    return out, elapsed_ms
def latency_report(name, durations_ms):
    """Print and return latency stats for a list of durations (ms).

    Returns a dict with ``avg``/``p50``/``p95``/``p99`` (milliseconds)
    and ``ops`` (operations per second derived from the average).

    An empty sample is reported as zeros instead of letting
    ``statistics.mean`` raise ``StatisticsError``, so a benchmark that
    recorded nothing still fails on its assertions, not on reporting.
    """
    if not durations_ms:
        print(f"  {name}: no samples recorded")
        return {"avg": 0, "p50": 0, "p95": 0, "p99": 0, "ops": 0}
    p50 = percentile(durations_ms, 50)
    p95 = percentile(durations_ms, 95)
    p99 = percentile(durations_ms, 99)
    avg = statistics.mean(durations_ms)
    # Guard against a zero average (sub-resolution timings) -> report inf.
    ops = 1000 / avg if avg > 0 else float("inf")
    print(
        f"  {name}: avg={avg:.2f}ms p50={p50:.2f}ms p95={p95:.2f}ms "
        f"p99={p99:.2f}ms ops/s={ops:.0f}"
    )
    return {"avg": avg, "p50": p50, "p95": p95, "p99": p99, "ops": ops}
def make_message(peer_hash, i, content_size=100):
    """Build a synthetic LXMF message row for benchmark seeding.

    Alternates incoming/outgoing by index parity and pads the body
    with *content_size* filler characters.
    """
    title = f"Message title {i} " + secrets.token_hex(8)
    body = f"Content body {i} " + "x" * content_size
    return {
        "hash": secrets.token_hex(16),
        "source_hash": peer_hash,
        "destination_hash": "local_hash_0" * 2,
        "peer_hash": peer_hash,
        "state": "delivered",
        "progress": 1.0,
        "is_incoming": i % 2,
        "method": "direct",
        "delivery_attempts": 1,
        "next_delivery_attempt_at": None,
        "title": title,
        "content": body,
        "fields": "{}",
        # Spread timestamps backwards so ordering queries have work to do.
        "timestamp": time.time() - i,
        "rssi": -50,
        "snr": 5.0,
        "quality": 3,
        "is_spam": 0,
        "reply_to_hash": None,
    }
def make_announce(i):
    """Build a synthetic announce row; every third index is telephony."""
    aspect = "lxst.telephony" if i % 3 == 0 else "lxmf.delivery"
    return {
        "destination_hash": secrets.token_hex(16),
        "aspect": aspect,
        "identity_hash": secrets.token_hex(16),
        "identity_public_key": "pubkey_" + secrets.token_hex(8),
        "app_data": "appdata_" + secrets.token_hex(16),
        # Vary signal metrics with the index so filters/sorts see spread.
        "rssi": -50 + (i % 30),
        "snr": 5.0,
        "quality": i % 10,
    }
# ---------------------------------------------------------------------------
# Test class
# ---------------------------------------------------------------------------
class TestPerformanceHotPaths(unittest.TestCase):
NUM_MESSAGES = 10_000
NUM_PEERS = 200
NUM_ANNOUNCES = 5_000
NUM_FAVOURITES = 100
NUM_CONTACTS = 50
@classmethod
def setUpClass(cls):
    """Create a temp on-disk DB, wire handlers, and seed fixture data once."""
    cls.test_dir = tempfile.mkdtemp()
    cls.db_path = os.path.join(cls.test_dir, "perf_hotpaths.db")
    cls.db = Database(cls.db_path)
    cls.db.initialize()
    cls.handler = MessageHandler(cls.db)
    cls.announce_mgr = AnnounceManager(cls.db)
    cls._seed_data()
@classmethod
def tearDownClass(cls):
    """Close all DB connections and remove the temp directory."""
    cls.db.close_all()
    shutil.rmtree(cls.test_dir, ignore_errors=True)
@classmethod
def _seed_data(cls):
    """Seed messages, announces, favourites, and contacts for the benchmarks.

    The first peer ("heavy peer") receives 2000 messages so single-
    conversation queries are stressed; the rest are spread round-robin.
    Each seeding phase runs inside one provider transaction for speed.
    """
    # NOTE: was print(f"...") with no placeholders (ruff F541); plain
    # string literals are used for all placeholder-free messages.
    print("\n--- Seeding test data ---")

    # Peers
    cls.peer_hashes = [secrets.token_hex(16) for _ in range(cls.NUM_PEERS)]
    cls.heavy_peer = cls.peer_hashes[0]

    # Messages: distribute across peers, heavy_peer gets 2000
    print(f"  Seeding {cls.NUM_MESSAGES} messages across {cls.NUM_PEERS} peers...")
    t0 = time.perf_counter()
    with cls.db.provider:
        for i in range(cls.NUM_MESSAGES):
            if i < 2000:
                peer = cls.heavy_peer
            else:
                peer = cls.peer_hashes[i % cls.NUM_PEERS]
            cls.db.messages.upsert_lxmf_message(make_message(peer, i))
    print(f"  Done in {(time.perf_counter() - t0) * 1000:.0f}ms")

    # Announces
    print(f"  Seeding {cls.NUM_ANNOUNCES} announces...")
    cls.announce_hashes = []
    t0 = time.perf_counter()
    with cls.db.provider:
        for i in range(cls.NUM_ANNOUNCES):
            data = make_announce(i)
            cls.announce_hashes.append(data["destination_hash"])
            cls.db.announces.upsert_announce(data)
    print(f"  Done in {(time.perf_counter() - t0) * 1000:.0f}ms")

    # Favourites
    print(f"  Seeding {cls.NUM_FAVOURITES} favourites...")
    with cls.db.provider:
        for i in range(cls.NUM_FAVOURITES):
            cls.db.announces.upsert_favourite(
                cls.announce_hashes[i],
                f"Fav Node {i}",
                "lxmf.delivery",
            )

    # Contacts (for JOIN benchmarks)
    print(f"  Seeding {cls.NUM_CONTACTS} contacts...")
    with cls.db.provider:
        for i in range(cls.NUM_CONTACTS):
            cls.db.contacts.add_contact(
                name=f"Contact {i}",
                remote_identity_hash=cls.peer_hashes[i % cls.NUM_PEERS],
                lxmf_address=cls.peer_hashes[(i + 1) % cls.NUM_PEERS],
            )
    print("--- Seeding complete ---\n")
# ===================================================================
# ANNOUNCES — load, search, count
# ===================================================================
def test_announce_load_filtered_latency(self):
    """Load announces filtered by aspect with pagination — the NomadNet browser default view."""
    print("\n[Announce] Filtered load (aspect + pagination):")
    durations = []
    # Sample both shallow and deep pages to catch OFFSET scaling issues.
    offsets = [0, 100, 500, 1000, 2000]
    for offset in offsets:
        _, ms = timed_call(
            self.announce_mgr.get_filtered_announces,
            aspect="lxmf.delivery",
            limit=50,
            offset=offset,
        )
        durations.append(ms)
    stats = latency_report("filtered_load", durations)
    # Regression budget: worst page load must stay under 100ms.
    self.assertLess(stats["p99"], 100, "Announce filtered load p99 > 100ms")
def test_announce_search_latency(self):
    """Search announces by destination/identity hash substring."""
    print("\n[Announce] LIKE search:")
    # Mix of random, common, known-prefix and guaranteed-miss terms.
    search_terms = [
        secrets.token_hex(4),
        "abc",
        self.announce_hashes[0][:8],
        self.announce_hashes[2500][:10],
        "nonexistent_term_xyz",
    ]
    durations = []
    for term in search_terms:
        _, ms = timed_call(
            self.announce_mgr.get_filtered_announces,
            aspect="lxmf.delivery",
            query=term,
            limit=50,
            offset=0,
        )
        durations.append(ms)
    stats = latency_report("search", durations)
    # Regression budget for LIKE scans over the seeded announce set.
    self.assertLess(stats["p95"], 150, "Announce search p95 > 150ms")
def test_announce_search_with_blocked(self):
    """Search with a block-list — simulates real NomadNet browser filtering."""
    print("\n[Announce] Search with blocked list:")
    # 50 random identity hashes approximate a realistic block-list size.
    blocked = [secrets.token_hex(16) for _ in range(50)]
    durations = []
    for _ in range(20):
        _, ms = timed_call(
            self.announce_mgr.get_filtered_announces,
            aspect="lxmf.delivery",
            query="abc",
            blocked_identity_hashes=blocked,
            limit=50,
            offset=0,
        )
        durations.append(ms)
    stats = latency_report("search+blocked", durations)
    self.assertLess(stats["p95"], 200, "Announce search+blocked p95 > 200ms")
def test_announce_count_latency(self):
    """Count announces (used for pagination total)."""
    print("\n[Announce] Count:")
    durations = []
    for _ in range(30):
        _, ms = timed_call(
            self.announce_mgr.get_filtered_announces_count,
            aspect="lxmf.delivery",
        )
        durations.append(ms)
    stats = latency_report("count", durations)
    self.assertLess(stats["p95"], 100, "Announce count p95 > 100ms")
# ===================================================================
# FAVOURITES
# ===================================================================
def test_favourites_load_latency(self):
    """Load all favourites — typically displayed in sidebar."""
    print("\n[Favourites] Load all:")
    durations = []
    for _ in range(50):
        _, ms = timed_call(self.db.announces.get_favourites, "lxmf.delivery")
        durations.append(ms)
    stats = latency_report("load_favs", durations)
    # Tight budget: the sidebar loads this on every refresh.
    self.assertLess(stats["p95"], 20, "Favourites load p95 > 20ms")
def test_favourite_upsert_throughput(self):
    """Measure upsert throughput for favourites."""
    print("\n[Favourites] Upsert throughput:")
    durations = []
    for i in range(100):
        # Fresh destination hash per iteration -> pure-insert path.
        dest = secrets.token_hex(16)
        _, ms = timed_call(
            self.db.announces.upsert_favourite,
            dest, f"Bench Fav {i}", "lxmf.delivery",
        )
        durations.append(ms)
    stats = latency_report("upsert_fav", durations)
    self.assertGreater(stats["ops"], 500, "Favourite upsert < 500 ops/s")
# ===================================================================
# CONVERSATIONS — load, search
# ===================================================================
def test_conversations_load_latency(self):
"""Load conversation list — the main messages sidebar query."""
print("\n[Conversations] Load list (with JOINs):")
durations = []
for _ in range(20):
_, ms = timed_call(
self.handler.get_conversations,
"local_hash",
limit=50,
offset=0,
)
durations.append(ms)
stats = latency_report("load_conversations", durations)
self.assertLess(stats["p95"], 200, "Conversation list p95 > 200ms")
def test_conversations_search_latency(self):
"""Search conversations — LIKE across titles, content, peer hashes."""
print("\n[Conversations] Search:")
terms = ["Message title 5", "Content body", "abc", "zzz_nope", self.heavy_peer[:8]]
durations = []
for term in terms:
_, ms = timed_call(
self.handler.get_conversations,
"local_hash",
search=term,
limit=50,
)
durations.append(ms)
stats = latency_report("search_conversations", durations)
self.assertLess(stats["p95"], 500, "Conversation search p95 > 500ms")
def test_conversations_load_paginated(self):
"""Paginate through conversation list at various offsets."""
print("\n[Conversations] Paginated load:")
durations = []
for offset in [0, 20, 50, 100, 150]:
_, ms = timed_call(
self.handler.get_conversations,
"local_hash",
limit=20,
offset=offset,
)
durations.append(ms)
stats = latency_report("paginated_conversations", durations)
self.assertLess(stats["p95"], 300, "Paginated conversations p95 > 300ms")
# ===================================================================
# MESSAGES — load, search, upsert (drafts)
# ===================================================================
def test_message_load_latency(self):
"""Load messages for a single conversation (heavy peer with 2000 msgs)."""
print("\n[Messages] Load conversation messages:")
durations = []
offsets = [0, 100, 500, 1000, 1900]
for offset in offsets:
result, ms = timed_call(
self.handler.get_conversation_messages,
"local_hash",
self.heavy_peer,
limit=50,
offset=offset,
)
durations.append(ms)
self.assertEqual(len(result), 50)
stats = latency_report("load_messages", durations)
self.assertLess(stats["p99"], 50, "Message load p99 > 50ms")
def test_message_search_latency(self):
"""Search messages across all conversations — the global search."""
print("\n[Messages] Global search:")
terms = [
"Message title 100",
"Content body 5000",
secrets.token_hex(4),
"nonexistent_xyz_123",
]
durations = []
for term in terms:
_, ms = timed_call(
self.handler.search_messages,
"local_hash",
term,
)
durations.append(ms)
stats = latency_report("search_messages", durations)
self.assertLess(stats["p95"], 300, "Message search p95 > 300ms")
def test_message_upsert_throughput(self):
"""Measure message upsert throughput — simulates saving drafts rapidly."""
print("\n[Messages] Upsert throughput (draft saves):")
durations = []
peer = secrets.token_hex(16)
for i in range(200):
msg = make_message(peer, i + 100000)
_, ms = timed_call(self.db.messages.upsert_lxmf_message, msg)
durations.append(ms)
stats = latency_report("upsert_message", durations)
self.assertGreater(stats["ops"], 300, "Message upsert < 300 ops/s")
self.assertLess(stats["p95"], 10, "Message upsert p95 > 10ms")
def test_message_upsert_update_throughput(self):
"""Measure message UPDATE throughput — re-saving existing messages (state changes)."""
print("\n[Messages] Update existing messages:")
peer = secrets.token_hex(16)
msgs = []
for i in range(100):
msg = make_message(peer, i + 200000)
self.db.messages.upsert_lxmf_message(msg)
msgs.append(msg)
durations = []
for msg in msgs:
msg["state"] = "failed"
msg["content"] = "Updated content " + secrets.token_hex(16)
_, ms = timed_call(self.db.messages.upsert_lxmf_message, msg)
durations.append(ms)
stats = latency_report("update_message", durations)
self.assertGreater(stats["ops"], 300, "Message update < 300 ops/s")
# ===================================================================
# CONCURRENT WRITERS — contention stress
# ===================================================================
def test_concurrent_message_writers(self):
"""Multiple threads inserting messages simultaneously."""
print("\n[Concurrency] Message writers:")
num_threads = 8
msgs_per_thread = 100
errors = []
all_durations = []
lock = threading.Lock()
def writer(thread_id):
thread_durations = []
peer = secrets.token_hex(16)
for i in range(msgs_per_thread):
msg = make_message(peer, thread_id * 10000 + i)
try:
_, ms = timed_call(self.db.messages.upsert_lxmf_message, msg)
thread_durations.append(ms)
except Exception as e:
errors.append(str(e))
with lock:
all_durations.extend(thread_durations)
threads = [threading.Thread(target=writer, args=(t,)) for t in range(num_threads)]
t0 = time.perf_counter()
for t in threads:
t.start()
for t in threads:
t.join()
wall_ms = (time.perf_counter() - t0) * 1000
total_ops = num_threads * msgs_per_thread
throughput = total_ops / (wall_ms / 1000)
print(f" Wall time: {wall_ms:.0f}ms for {total_ops} inserts ({throughput:.0f} ops/s)")
stats = latency_report("concurrent_write", all_durations)
self.assertEqual(len(errors), 0, f"Writer errors: {errors[:5]}")
self.assertGreater(throughput, 100, "Concurrent write throughput < 100 ops/s")
def test_concurrent_announce_writers(self):
"""Multiple threads upserting announces simultaneously."""
print("\n[Concurrency] Announce writers:")
num_threads = 6
announces_per_thread = 100
errors = []
all_durations = []
lock = threading.Lock()
def writer(thread_id):
thread_durations = []
for i in range(announces_per_thread):
data = make_announce(thread_id * 10000 + i)
try:
_, ms = timed_call(self.db.announces.upsert_announce, data)
thread_durations.append(ms)
except Exception as e:
errors.append(str(e))
with lock:
all_durations.extend(thread_durations)
threads = [threading.Thread(target=writer, args=(t,)) for t in range(num_threads)]
t0 = time.perf_counter()
for t in threads:
t.start()
for t in threads:
t.join()
wall_ms = (time.perf_counter() - t0) * 1000
total_ops = num_threads * announces_per_thread
throughput = total_ops / (wall_ms / 1000)
print(f" Wall time: {wall_ms:.0f}ms for {total_ops} upserts ({throughput:.0f} ops/s)")
stats = latency_report("concurrent_announce_write", all_durations)
self.assertEqual(len(errors), 0, f"Writer errors: {errors[:5]}")
self.assertGreater(throughput, 100, "Concurrent announce write < 100 ops/s")
    def test_concurrent_read_write_contention(self):
        """Writers inserting while readers query — simulates real app usage."""
        print("\n[Contention] Mixed read/write:")
        num_writers = 4
        num_readers = 4
        ops_per_thread = 50
        write_errors = []
        read_errors = []
        write_durations = []
        read_durations = []
        # Guards the shared duration lists; each thread batches locally first.
        lock = threading.Lock()
        def writer(thread_id):
            # Inserts ops_per_thread messages, each batch under a unique peer hash.
            local_durs = []
            peer = secrets.token_hex(16)
            for i in range(ops_per_thread):
                msg = make_message(peer, thread_id * 10000 + i)
                try:
                    _, ms = timed_call(self.db.messages.upsert_lxmf_message, msg)
                    local_durs.append(ms)
                except Exception as e:
                    write_errors.append(str(e))
            with lock:
                write_durations.extend(local_durs)
        def reader(_thread_id):
            # Repeatedly loads the conversation list while writers are active.
            local_durs = []
            for _ in range(ops_per_thread):
                try:
                    _, ms = timed_call(
                        self.handler.get_conversations,
                        "local_hash",
                        limit=20,
                    )
                    local_durs.append(ms)
                except Exception as e:
                    read_errors.append(str(e))
            with lock:
                read_durations.extend(local_durs)
        writers = [threading.Thread(target=writer, args=(t,)) for t in range(num_writers)]
        readers = [threading.Thread(target=reader, args=(t,)) for t in range(num_readers)]
        t0 = time.perf_counter()
        # Start writers and readers together so their work genuinely overlaps.
        for t in writers + readers:
            t.start()
        for t in writers + readers:
            t.join()
        wall_ms = (time.perf_counter() - t0) * 1000
        print(f" Wall time: {wall_ms:.0f}ms")
        # Latency reports are informational; the assertions only check for errors.
        latency_report("contention_writes", write_durations)
        latency_report("contention_reads", read_durations)
        self.assertEqual(len(write_errors), 0, f"Write errors: {write_errors[:5]}")
        self.assertEqual(len(read_errors), 0, f"Read errors: {read_errors[:5]}")
# ===================================================================
# LIKE SEARCH SCALING — how search degrades with data size
# ===================================================================
def test_like_search_scaling(self):
"""Measure how LIKE search scales across different table sizes.
This catches missing FTS indexes or query plan regressions."""
print("\n[Scaling] LIKE search across data sizes:")
# Message search on the existing 10k dataset
_, ms_msg = timed_call(
self.handler.search_messages,
"local_hash",
"Content body",
)
print(f" Message LIKE search ({self.NUM_MESSAGES} rows): {ms_msg:.2f}ms")
self.assertLess(ms_msg, 500, "Message LIKE search > 500ms on 10k rows")
# Announce search on the existing 5k dataset
_, ms_ann = timed_call(
self.announce_mgr.get_filtered_announces,
aspect="lxmf.delivery",
query="abc",
limit=50,
)
print(f" Announce LIKE search ({self.NUM_ANNOUNCES} rows): {ms_ann:.2f}ms")
self.assertLess(ms_ann, 200, "Announce LIKE search > 200ms on 5k rows")
# Contacts search
_, ms_con = timed_call(self.db.contacts.get_contacts, search="Contact")
print(f" Contacts LIKE search ({self.NUM_CONTACTS} rows): {ms_con:.2f}ms")
self.assertLess(ms_con, 50, "Contacts LIKE search > 50ms")
# Allow running this benchmark module directly from the command line.
if __name__ == "__main__":
    unittest.main()
# ===========================================================================
# tests/test_sql_injection.py
# ===========================================================================
"""Tests confirming SQL-injection fixes in the raw-SQL database layer.
Covers:
- ATTACH DATABASE path escaping (single-quote doubling) in LegacyMigrator
- Column-name identifier filtering during legacy migration
- _validate_identifier / _ensure_column rejection in DatabaseSchema
"""
import os
import re
import sqlite3
import tempfile
import pytest
from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st
from meshchatx.src.backend.database.legacy_migrator import LegacyMigrator
from meshchatx.src.backend.database.provider import DatabaseProvider
from meshchatx.src.backend.database.schema import DatabaseSchema, _validate_identifier
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def db_env(tmp_path):
    """Provide an initialized DatabaseProvider + Schema in a temp directory.

    Yields (provider, schema, tmp_path). The provider is closed in a
    ``finally`` block so the connection is released even if the fixture
    generator is finalized abnormally (e.g. the test session aborts).
    """
    db_path = str(tmp_path / "current.db")
    provider = DatabaseProvider(db_path)
    schema = DatabaseSchema(provider)
    schema.initialize()
    try:
        yield provider, schema, tmp_path
    finally:
        provider.close()
def _make_legacy_db(legacy_dir, identity_hash, tables_sql):
"""Create a legacy database with the given CREATE TABLE + INSERT statements."""
identity_dir = os.path.join(legacy_dir, "identities", identity_hash)
os.makedirs(identity_dir, exist_ok=True)
db_path = os.path.join(identity_dir, "database.db")
conn = sqlite3.connect(db_path)
for sql in tables_sql:
conn.execute(sql)
conn.commit()
conn.close()
return db_path
# ---------------------------------------------------------------------------
# 1. ATTACH DATABASE — single-quote escaping
# ---------------------------------------------------------------------------
class TestAttachDatabasePathEscaping:
    # Exercises the single-quote doubling applied to the legacy db path before
    # it is interpolated into the ATTACH DATABASE statement.

    def test_path_without_quotes_migrates_normally(self, db_env):
        """Baseline: a plain path migrates and the data survives."""
        provider, _schema, tmp_path = db_env
        legacy_dir = str(tmp_path / "legacy_normal")
        identity_hash = "aabbccdd"
        _make_legacy_db(
            legacy_dir,
            identity_hash,
            [
                "CREATE TABLE config (key TEXT UNIQUE, value TEXT)",
                "INSERT INTO config (key, value) VALUES ('k1', 'v1')",
            ],
        )
        migrator = LegacyMigrator(provider, legacy_dir, identity_hash)
        assert migrator.migrate() is True
        # The legacy row must now exist in the current database.
        row = provider.fetchone("SELECT value FROM config WHERE key = 'k1'")
        assert row is not None
        assert row["value"] == "v1"

    def test_path_with_single_quote_does_not_crash(self, db_env):
        """A path containing a single quote must not cause SQL injection or crash."""
        provider, _schema, tmp_path = db_env
        # The directory name itself carries the problematic quote.
        quoted_dir = tmp_path / "it's_a_test"
        quoted_dir.mkdir(parents=True, exist_ok=True)
        legacy_dir = str(quoted_dir)
        identity_hash = "aabbccdd"
        _make_legacy_db(
            legacy_dir,
            identity_hash,
            [
                "CREATE TABLE config (key TEXT UNIQUE, value TEXT)",
                "INSERT INTO config (key, value) VALUES ('q1', 'quoted_val')",
            ],
        )
        migrator = LegacyMigrator(provider, legacy_dir, identity_hash)
        result = migrator.migrate()
        assert result is True
        row = provider.fetchone("SELECT value FROM config WHERE key = 'q1'")
        assert row is not None
        assert row["value"] == "quoted_val"

    def test_path_with_multiple_quotes(self, db_env):
        """Multiple single quotes in the path are all escaped."""
        provider, _schema, tmp_path = db_env
        weird_dir = tmp_path / "a'b'c"
        weird_dir.mkdir(parents=True, exist_ok=True)
        legacy_dir = str(weird_dir)
        identity_hash = "11223344"
        _make_legacy_db(
            legacy_dir,
            identity_hash,
            [
                "CREATE TABLE config (key TEXT UNIQUE, value TEXT)",
                "INSERT INTO config (key, value) VALUES ('mq', 'multi_quote')",
            ],
        )
        migrator = LegacyMigrator(provider, legacy_dir, identity_hash)
        assert migrator.migrate() is True
        row = provider.fetchone("SELECT value FROM config WHERE key = 'mq'")
        assert row is not None
        assert row["value"] == "multi_quote"

    def test_path_with_sql_injection_attempt(self, db_env):
        """A path crafted to look like SQL injection is safely escaped."""
        provider, _schema, tmp_path = db_env
        # Classic quote-break + DROP TABLE payload embedded in the path.
        evil_dir = tmp_path / "'; DROP TABLE config; --"
        evil_dir.mkdir(parents=True, exist_ok=True)
        legacy_dir = str(evil_dir)
        identity_hash = "deadbeef"
        _make_legacy_db(
            legacy_dir,
            identity_hash,
            [
                "CREATE TABLE config (key TEXT UNIQUE, value TEXT)",
                "INSERT INTO config (key, value) VALUES ('evil', 'nope')",
            ],
        )
        migrator = LegacyMigrator(provider, legacy_dir, identity_hash)
        migrator.migrate()
        # Regardless of whether the migration succeeded, the payload must not
        # have executed: the config table still exists in the current db.
        tables = provider.fetchall(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='config'"
        )
        assert len(tables) > 0, "config table must still exist after injection attempt"
# ---------------------------------------------------------------------------
# 2. Legacy migrator — malicious column names filtered out
# ---------------------------------------------------------------------------
class TestLegacyColumnFiltering:
    # Verifies the _IDENTIFIER_RE filter applied to legacy column names before
    # they are interpolated into the INSERT ... SELECT migration statements.

    def test_normal_columns_migrate(self, db_env):
        """Standard column names pass through the identifier filter."""
        provider, _schema, tmp_path = db_env
        legacy_dir = str(tmp_path / "legacy_cols")
        identity_hash = "aabb0011"
        _make_legacy_db(
            legacy_dir,
            identity_hash,
            [
                "CREATE TABLE config (key TEXT UNIQUE, value TEXT)",
                "INSERT INTO config (key, value) VALUES ('c1', 'ok')",
            ],
        )
        migrator = LegacyMigrator(provider, legacy_dir, identity_hash)
        assert migrator.migrate() is True
        row = provider.fetchone("SELECT value FROM config WHERE key = 'c1'")
        assert row is not None

    def test_malicious_column_name_is_skipped(self, db_env):
        """A column with SQL metacharacters in its name must be silently skipped."""
        provider, _schema, tmp_path = db_env
        legacy_dir = str(tmp_path / "legacy_evil_col")
        identity_hash = "cc00dd00"
        # Built by hand (not via _make_legacy_db) to control the exact DDL.
        identity_dir = os.path.join(legacy_dir, "identities", identity_hash)
        os.makedirs(identity_dir, exist_ok=True)
        db_path = os.path.join(identity_dir, "database.db")
        conn = sqlite3.connect(db_path)
        # Quoted DDL lets SQLite accept a column name carrying a payload.
        conn.execute(
            'CREATE TABLE config (key TEXT UNIQUE, value TEXT, "key; DROP TABLE config" TEXT)'
        )
        conn.execute(
            "INSERT INTO config (key, value) VALUES ('safe', 'data')"
        )
        conn.commit()
        conn.close()
        migrator = LegacyMigrator(provider, legacy_dir, identity_hash)
        # Migration still succeeds; only the well-formed columns are copied.
        assert migrator.migrate() is True
        row = provider.fetchone("SELECT value FROM config WHERE key = 'safe'")
        assert row is not None
        assert row["value"] == "data"

    def test_column_with_parentheses_is_skipped(self, db_env):
        """Columns with () in the name are rejected by the identifier regex."""
        provider, _schema, tmp_path = db_env
        legacy_dir = str(tmp_path / "legacy_parens_col")
        identity_hash = "ee00ff00"
        identity_dir = os.path.join(legacy_dir, "identities", identity_hash)
        os.makedirs(identity_dir, exist_ok=True)
        db_path = os.path.join(identity_dir, "database.db")
        conn = sqlite3.connect(db_path)
        conn.execute(
            'CREATE TABLE config (key TEXT UNIQUE, value TEXT, "evil()" TEXT)'
        )
        conn.execute(
            "INSERT INTO config (key, value) VALUES ('p1', 'parens')"
        )
        conn.commit()
        conn.close()
        migrator = LegacyMigrator(provider, legacy_dir, identity_hash)
        assert migrator.migrate() is True
        row = provider.fetchone("SELECT value FROM config WHERE key = 'p1'")
        assert row is not None
# ---------------------------------------------------------------------------
# 3. _validate_identifier — unit tests
# ---------------------------------------------------------------------------
class TestValidateIdentifier:
    # Unit tests for schema._validate_identifier: accepts [A-Za-z_][A-Za-z0-9_]*,
    # raises ValueError for anything else.

    @pytest.mark.parametrize(
        "name",
        [
            "config",
            "lxmf_messages",
            "A",
            "_private",
            "Column123",
            "a_b_c_d",
        ],
    )
    def test_valid_identifiers_pass(self, name):
        # Valid names are returned unchanged (the function is pass-through).
        assert _validate_identifier(name) == name

    @pytest.mark.parametrize(
        "name",
        [
            "",
            "123abc",
            "table name",
            "col;drop",
            "a'b",
            'a"b',
            "col()",
            "x--y",
            "a,b",
            "hello\nworld",
            "tab\there",
            "col/**/name",
        ],
    )
    def test_invalid_identifiers_raise(self, name):
        # Empty strings, leading digits, whitespace and SQL metacharacters all fail.
        with pytest.raises(ValueError, match="Invalid SQL"):
            _validate_identifier(name)
# ---------------------------------------------------------------------------
# 4. _ensure_column — rejects injection via table/column names
# ---------------------------------------------------------------------------
class TestEnsureColumnInjection:
    """_ensure_column must validate identifiers before building ALTER TABLE SQL."""

    def test_ensure_column_rejects_malicious_table_name(self, db_env):
        _provider, schema, _tmp_path = db_env
        with pytest.raises(ValueError, match="Invalid SQL table name"):
            schema._ensure_column("config; DROP TABLE config", "new_col", "TEXT")

    def test_ensure_column_rejects_malicious_column_name(self, db_env):
        _provider, schema, _tmp_path = db_env
        with pytest.raises(ValueError, match="Invalid SQL column name"):
            schema._ensure_column("config", "col; DROP TABLE config", "TEXT")

    def test_ensure_column_works_for_valid_names(self, db_env):
        _provider, schema, _tmp_path = db_env
        assert schema._ensure_column("config", "test_new_col", "TEXT") is True

    def test_ensure_column_idempotent(self, db_env):
        # A second call for an already-added column must still report success.
        _provider, schema, _tmp_path = db_env
        schema._ensure_column("config", "idempotent_col", "TEXT")
        assert schema._ensure_column("config", "idempotent_col", "TEXT") is True
# ---------------------------------------------------------------------------
# 5. Property-based tests — identifier regex
# ---------------------------------------------------------------------------
_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
@given(name=st.text(min_size=1, max_size=80))
@settings(deadline=None)
def test_validate_identifier_never_allows_sql_metacharacters(name):
    """No string accepted by _validate_identifier contains SQL metacharacters."""
    try:
        _validate_identifier(name)
    except ValueError:
        # Rejected input: nothing further to verify.
        return
    # Accepted input must be free of every SQL metacharacter we care about.
    for forbidden in (";", "'", '"', "(", ")", " ", "-", "/", "\\", "\n", "\r", "\t", ","):
        assert forbidden not in name
    # And it must match the canonical identifier pattern exactly.
    assert _IDENTIFIER_RE.match(name)
@given(name=st.from_regex(r"[A-Za-z_][A-Za-z0-9_]{0,30}", fullmatch=True))
@settings(deadline=None)
def test_validate_identifier_accepts_all_valid_identifiers(name):
    """Every string matching the identifier pattern round-trips unchanged."""
    result = _validate_identifier(name)
    assert result == name
@given(
    name=st.text(
        alphabet=st.sampled_from(list(";'\"()- \t\n\r,/*")),
        min_size=1,
        max_size=30,
    )
)
@settings(deadline=None)
def test_validate_identifier_rejects_pure_metacharacter_strings(name):
    """Strings built entirely from SQL metacharacters must always be rejected."""
    # Every character in the alphabet is outside [A-Za-z0-9_], so no such
    # string can ever satisfy the identifier regex.
    with pytest.raises(ValueError):
        _validate_identifier(name)
# ---------------------------------------------------------------------------
# 6. ATTACH path escaping — property-based
# ---------------------------------------------------------------------------
@given(
    path_segment=st.text(
        alphabet=st.characters(
            whitelist_categories=("L", "N", "P", "S", "Z"),
        ),
        min_size=1,
        max_size=60,
    )
)
@settings(deadline=None, suppress_health_check=[HealthCheck.too_slow])
def test_attach_path_escaping_never_breaks_sql(path_segment):
    """The quote-doubling escaping produces a string that SQLite can parse
    without breaking out of the literal, regardless of the path content.

    Fix over the original version: the scan now *asserts* that the string
    literal terminates. Previously, if no closing quote were ever found the
    test fell through and passed vacuously (and ``remainder`` stayed unbound).
    """
    safe = path_segment.replace("'", "''")
    sql = f"ATTACH DATABASE '{safe}' AS test_alias"
    assert sql.count("ATTACH DATABASE '") == 1
    after_open = sql.split("ATTACH DATABASE '", 1)[1]
    # Scan the literal the way SQLite does: a doubled quote is an escaped
    # quote inside the literal; a lone quote terminates it.
    remainder = None
    i = 0
    while i < len(after_open):
        if after_open[i] == "'":
            if after_open[i + 1 : i + 2] == "'":
                i += 2
                continue
            remainder = after_open[i + 1 :]
            break
        i += 1
    # The f-string always appends a closing quote, so the literal must end.
    assert remainder is not None, "string literal never terminated"
    assert remainder.strip() == "AS test_alias", (
        f"Unexpected SQL after literal end: {remainder!r}"
    )