mirror of https://git.quad4.io/RNS-Things/MeshChatX.git
synced 2026-05-10 16:25:22 +00:00
Add SQL injection prevention measures in database migration and schema validation
- Implemented an identifier-validation regex in the legacy migrator and schema.
- Enhanced database path handling to escape single quotes during ATTACH DATABASE.
- Added tests for SQL injection scenarios, ensuring robustness against malicious inputs.
- Introduced fuzz tests for the DAO layers to cover edge cases and improve overall test coverage.
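For orientation, a minimal self-contained sketch of the two techniques this commit combines — allow-list validation for SQL identifiers and single-quote doubling for the ATTACH path. Names and signatures here are illustrative, not the project's actual API:

import re
import sqlite3

# Allow-list: a leading letter or underscore, then letters, digits, underscores.
_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def validate_identifier(name: str) -> str:
    """Reject anything that could smuggle SQL through an interpolated identifier."""
    if not _IDENTIFIER_RE.match(name):
        raise ValueError(f"Invalid SQL identifier: {name!r}")
    return name


def attach_legacy(conn: sqlite3.Connection, legacy_path: str, alias: str) -> None:
    # ATTACH embeds the path as a string literal, so SQLite's escaping rule
    # applies: double every single quote inside the literal.
    safe_path = legacy_path.replace("'", "''")
    conn.execute(f"ATTACH DATABASE '{safe_path}' AS {validate_identifier(alias)}")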
@@ -1,4 +1,7 @@
 import os
+import re
+
+_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


 class LegacyMigrator:
@@ -73,7 +76,8 @@ class LegacyMigrator:
         # Attach the legacy database
         # We use a randomized alias to avoid collisions
         alias = f"legacy_{os.urandom(4).hex()}"
-        self.provider.execute(f"ATTACH DATABASE '{legacy_path}' AS {alias}")
+        safe_path = legacy_path.replace("'", "''")
+        self.provider.execute(f"ATTACH DATABASE '{safe_path}' AS {alias}")  # noqa: S608

         # Tables that existed in the legacy Peewee version
         tables_to_migrate = [
@@ -121,7 +125,9 @@ class LegacyMigrator:
             common_columns = [
                 col
                 for col in legacy_columns
-                if col in current_columns and col.lower() != "id"
+                if col in current_columns
+                and col.lower() != "id"
+                and _IDENTIFIER_RE.match(col)
             ]

             if common_columns:
@@ -1,5 +1,16 @@
+import re
+
 from .provider import DatabaseProvider

+_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
+
+
+def _validate_identifier(name: str, label: str = "identifier") -> str:
+    if not _IDENTIFIER_RE.match(name):
+        msg = f"Invalid SQL {label}: {name!r}"
+        raise ValueError(msg)
+    return name
+

 class DatabaseSchema:
     LATEST_VERSION = 38
@@ -28,21 +39,18 @@ class DatabaseSchema:

     def _ensure_column(self, table_name, column_name, column_type):
         """Add a column to a table if it doesn't exist."""
         # First check if it exists using PRAGMA
+        _validate_identifier(table_name, "table name")
+        _validate_identifier(column_name, "column name")

         cursor = self.provider.connection.cursor()
         try:
-            cursor.execute(f"PRAGMA table_info({table_name})")
+            cursor.execute(f"PRAGMA table_info({table_name})")  # noqa: S608
             columns = [row[1] for row in cursor.fetchall()]
         finally:
             cursor.close()

         if column_name not in columns:
             try:
                 # SQLite has limitations on ALTER TABLE ADD COLUMN:
                 # 1. Cannot add UNIQUE or PRIMARY KEY columns
                 # 2. Cannot add columns with non-constant defaults (like CURRENT_TIMESTAMP)

                 # Strip non-constant defaults if present for the ALTER TABLE statement
                 stmt_type = column_type
                 forbidden_defaults = [
                     "CURRENT_TIMESTAMP",
@@ -51,9 +59,6 @@ class DatabaseSchema:
                 ]
                 for forbidden in forbidden_defaults:
                     if f"DEFAULT {forbidden}" in stmt_type.upper():
                         # Remove the DEFAULT part for the ALTER statement
-                        import re
-
                         stmt_type = re.sub(
                             f"DEFAULT\\s+{forbidden}",
                             "",
@@ -61,9 +66,8 @@ class DatabaseSchema:
                             flags=re.IGNORECASE,
                         ).strip()

                 # Use the connection directly to avoid any middle-ware issues
                 res = self._safe_execute(
-                    f"ALTER TABLE {table_name} ADD COLUMN {column_name} {stmt_type}",
+                    f"ALTER TABLE {table_name} ADD COLUMN {column_name} {stmt_type}",  # noqa: S608
                 )
                 return res is not None
             except Exception as e:
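The `# noqa: S608` markers above flag deliberate string interpolation: SQLite will not bind parameters inside PRAGMA or DDL statements, so table and column names have to be interpolated, and identifier validation is the only defense. A standalone demonstration of that constraint (illustrative, independent of the project code):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE contacts (id INTEGER PRIMARY KEY, name TEXT)")

# Binding works for values...
conn.execute("INSERT INTO contacts (name) VALUES (?)", ("alice",))

# ...but not for identifiers: a placeholder here is a syntax error,
# which is why _ensure_column must interpolate a validated table name.
try:
    conn.execute("PRAGMA table_info(?)", ("contacts",))
except sqlite3.OperationalError:
    pass

conn.execute("PRAGMA table_info(contacts)")  # validated interpolation instead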
@@ -0,0 +1,717 @@
"""Property-based and fuzz tests for DAO layers, MessageHandler, and utility functions.

Covers gaps identified in the existing test suite:
- ContactsDAO: add, search, update with adversarial strings
- ConfigDAO: get/set/delete round-trip with arbitrary keys and values
- MiscDAO: spam keywords, notifications, keyboard shortcuts
- TelephoneDAO: add + search with adversarial strings
- VoicemailDAO: add + search
- DebugLogsDAO: insert + search + count consistency
- RingtoneDAO: add + update with adversarial display names
- MapDrawingsDAO: upsert + update with arbitrary JSON data
- MessageDAO: folder create/rename
- MessageHandler: search_messages, get_conversations with adversarial search terms
- _safe_href: URL sanitization (XSS vectors)
- parse_bool_query_param: arbitrary strings
- message_fields_have_attachments: arbitrary JSON
"""

import json

import pytest
from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st

from meshchatx.src.backend.database import Database
from meshchatx.src.backend.markdown_renderer import _safe_href
from meshchatx.src.backend.meshchat_utils import (
    message_fields_have_attachments,
    parse_bool_query_param,
)
from meshchatx.src.backend.message_handler import MessageHandler


# ---------------------------------------------------------------------------
# Strategies
# ---------------------------------------------------------------------------

# Strings that are valid for most text columns but include adversarial chars
st_nasty_text = st.text(
    alphabet=st.characters(whitelist_categories=("L", "N", "P", "S", "Z", "C")),
    min_size=0,
    max_size=300,
)

st_sql_payloads = st.sampled_from([
    "'; DROP TABLE config; --",
    "' OR '1'='1",
    "\" OR \"1\"=\"1",
    "1; SELECT * FROM sqlite_master",
    "Robert'); DROP TABLE lxmf_messages;--",
    "%",
    "%%",
    "_",
    "\x00",
    "' UNION SELECT key, value FROM config --",
    "NULL",
    "1=1",
    "admin'--",
    "' AND 1=CONVERT(int,(SELECT TOP 1 table_name FROM information_schema.tables))--",
])

st_search_term = st.one_of(st_nasty_text, st_sql_payloads)

st_hex_hash = st.from_regex(r"[0-9a-f]{16,64}", fullmatch=True)


# ---------------------------------------------------------------------------
# Fixture: initialised in-memory database
# ---------------------------------------------------------------------------

@pytest.fixture
def db():
    database = Database(":memory:")
    database.initialize()
    yield database
    database.close()


@pytest.fixture
def handler(db):
    return MessageHandler(db)


# ===================================================================
# ContactsDAO
# ===================================================================

class TestContactsDAOFuzzing:

    @given(
        name=st_nasty_text,
        identity_hash=st_hex_hash,
        lxmf_addr=st.one_of(st.none(), st_hex_hash),
        lxst_addr=st.one_of(st.none(), st_hex_hash),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=60,
    )
    def test_add_contact_roundtrip(self, db, name, identity_hash, lxmf_addr, lxst_addr):
        db.contacts.add_contact(
            name=name,
            remote_identity_hash=identity_hash,
            lxmf_address=lxmf_addr,
            lxst_address=lxst_addr,
        )
        row = db.contacts.get_contact_by_identity_hash(identity_hash)
        assert row is not None
        assert row["name"] == name

    @given(search=st_search_term)
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=80,
    )
    def test_search_contacts_never_crashes(self, db, search):
        results = db.contacts.get_contacts(search=search)
        assert isinstance(results, list)
        count = db.contacts.get_contacts_count(search=search)
        assert isinstance(count, int)
        assert count >= 0

    @given(
        name=st_nasty_text,
        identity_hash=st_hex_hash,
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_update_contact_never_crashes(self, db, name, identity_hash):
        db.contacts.add_contact(name="orig", remote_identity_hash=identity_hash)
        row = db.contacts.get_contact_by_identity_hash(identity_hash)
        assert row is not None
        db.contacts.update_contact(row["id"], name=name)
        updated = db.contacts.get_contact(row["id"])
        assert updated["name"] == name


# ===================================================================
# ConfigDAO
# ===================================================================

class TestConfigDAOFuzzing:

    @given(key=st_nasty_text.filter(lambda x: len(x) > 0), value=st_nasty_text)
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=80,
    )
    def test_config_set_get_roundtrip(self, db, key, value):
        db.config.set(key, value)
        got = db.config.get(key)
        assert got == value

    @given(key=st_nasty_text.filter(lambda x: len(x) > 0))
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_config_delete_never_crashes(self, db, key):
        db.config.set(key, "temp")
        db.config.delete(key)
        assert db.config.get(key) is None

    @given(key=st_sql_payloads)
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
    )
    def test_config_sql_injection_keys(self, db, key):
        db.config.set(key, "injected?")
        assert db.config.get(key) == "injected?"
        tables = db.provider.fetchall(
            "SELECT name FROM sqlite_master WHERE type='table'"
        )
        table_names = {r["name"] for r in tables}
        assert "config" in table_names


# ===================================================================
# MiscDAO — spam keywords, notifications, keyboard shortcuts
# ===================================================================

class TestMiscDAOFuzzing:

    @given(keyword=st_nasty_text.filter(lambda x: len(x) > 0))
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=50,
    )
    def test_spam_keyword_roundtrip(self, db, keyword):
        db.misc.add_spam_keyword(keyword)
        keywords = db.misc.get_spam_keywords()
        assert any(k["keyword"] == keyword for k in keywords)

    @given(title=st_nasty_text, content=st_nasty_text)
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=50,
    )
    def test_check_spam_never_crashes(self, db, title, content):
        result = db.misc.check_spam_keywords(title, content)
        assert isinstance(result, bool)

    @given(
        ntype=st.sampled_from(["message", "call", "voicemail", "system"]),
        remote_hash=st_hex_hash,
        title=st_nasty_text,
        content=st_nasty_text,
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_add_notification_never_crashes(self, db, ntype, remote_hash, title, content):
        db.misc.add_notification(ntype, remote_hash, title, content)
        notifications = db.misc.get_notifications()
        assert isinstance(notifications, list)

    @given(
        action=st_nasty_text.filter(lambda x: len(x) > 0),
        keys=st_nasty_text,
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_keyboard_shortcut_roundtrip(self, db, action, keys):
        identity = "testhash"
        db.misc.upsert_keyboard_shortcut(identity, action, keys)
        shortcuts = db.misc.get_keyboard_shortcuts(identity)
        assert any(s["action"] == action and s["keys"] == keys for s in shortcuts)

    @given(
        query=st_search_term,
        dest=st.one_of(st.none(), st_hex_hash),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_archived_pages_search_never_crashes(self, db, query, dest):
        results = db.misc.get_archived_pages_paginated(
            destination_hash=dest,
            query=query if query else None,
        )
        assert isinstance(results, list)


# ===================================================================
# TelephoneDAO
# ===================================================================

class TestTelephoneDAOFuzzing:

    @given(
        name=st_nasty_text,
        identity_hash=st_hex_hash,
        status=st.sampled_from(["answered", "missed", "rejected", "busy"]),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_add_call_history_never_crashes(self, db, name, identity_hash, status):
        import time

        db.telephone.add_call_history(
            remote_identity_hash=identity_hash,
            remote_identity_name=name,
            is_incoming=True,
            status=status,
            duration_seconds=42,
            timestamp=time.time(),
        )
        history = db.telephone.get_call_history()
        assert isinstance(history, list)
        assert len(history) > 0

    @given(search=st_search_term)
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=60,
    )
    def test_call_history_search_never_crashes(self, db, search):
        results = db.telephone.get_call_history(search=search)
        assert isinstance(results, list)

    @given(search=st_search_term)
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=60,
    )
    def test_call_recordings_search_never_crashes(self, db, search):
        results = db.telephone.get_call_recordings(search=search)
        assert isinstance(results, list)


# ===================================================================
# VoicemailDAO
# ===================================================================

class TestVoicemailDAOFuzzing:

    @given(
        name=st_nasty_text,
        identity_hash=st_hex_hash,
        filename=st_nasty_text,
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_add_voicemail_never_crashes(self, db, name, identity_hash, filename):
        import time

        db.voicemails.add_voicemail(
            remote_identity_hash=identity_hash,
            remote_identity_name=name,
            filename=filename,
            duration_seconds=10,
            timestamp=time.time(),
        )
        vms = db.voicemails.get_voicemails()
        assert isinstance(vms, list)

    @given(search=st_search_term)
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=60,
    )
    def test_voicemail_search_never_crashes(self, db, search):
        results = db.voicemails.get_voicemails(search=search)
        assert isinstance(results, list)


# ===================================================================
# DebugLogsDAO
# ===================================================================

class TestDebugLogsDAOFuzzing:

    @given(
        level=st.sampled_from(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
        module=st_nasty_text,
        message=st_nasty_text,
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=50,
    )
    def test_insert_log_never_crashes(self, db, level, module, message):
        db.debug_logs.insert_log(level, module, message)
        total = db.debug_logs.get_total_count()
        assert total > 0

    @given(
        search=st_search_term,
        level=st.one_of(st.none(), st.sampled_from(["DEBUG", "INFO", "ERROR"])),
        module=st.one_of(st.none(), st_nasty_text),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=60,
    )
    def test_search_logs_never_crashes(self, db, search, level, module):
        results = db.debug_logs.get_logs(search=search, level=level, module=module)
        assert isinstance(results, list)
        count = db.debug_logs.get_total_count(search=search, level=level, module=module)
        assert isinstance(count, int)
        assert count >= 0

    @given(
        search=st_search_term,
        level=st.one_of(st.none(), st.sampled_from(["DEBUG", "INFO", "ERROR"])),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_count_matches_result_length(self, db, search, level):
        """get_total_count must be consistent with len(get_logs) when no limit truncation."""
        results = db.debug_logs.get_logs(
            search=search, level=level, limit=100000, offset=0,
        )
        count = db.debug_logs.get_total_count(search=search, level=level)
        assert count == len(results)


# ===================================================================
# RingtoneDAO
# ===================================================================

class TestRingtoneDAOFuzzing:

    @given(
        filename=st_nasty_text.filter(lambda x: len(x) > 0),
        display_name=st_nasty_text.filter(lambda x: len(x) > 0),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_add_ringtone_roundtrip(self, db, filename, display_name):
        rid = db.ringtones.add(filename, f"store_{filename[:20]}", display_name)
        assert rid is not None
        row = db.ringtones.get_by_id(rid)
        assert row is not None
        assert row["display_name"] == display_name

    @given(new_name=st_nasty_text.filter(lambda x: len(x) > 0))
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=30,
    )
    def test_update_display_name(self, db, new_name):
        rid = db.ringtones.add("test.ogg", "store_test.ogg", "Original")
        db.ringtones.update(rid, display_name=new_name)
        row = db.ringtones.get_by_id(rid)
        assert row["display_name"] == new_name


# ===================================================================
# MapDrawingsDAO
# ===================================================================

class TestMapDrawingsDAOFuzzing:

    @given(
        name=st_nasty_text.filter(lambda x: len(x) > 0),
        data=st.one_of(
            st_nasty_text,
            st.builds(json.dumps, st.dictionaries(st.text(max_size=10), st.text(max_size=50), max_size=10)),
        ),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_upsert_drawing_roundtrip(self, db, name, data):
        identity = "deadbeef01234567"
        db.map_drawings.upsert_drawing(identity, name, data)
        drawings = db.map_drawings.get_drawings(identity)
        assert any(d["name"] == name for d in drawings)


# ===================================================================
# MessageDAO — folders
# ===================================================================

class TestMessageDAOFoldersFuzzing:

    @given(name=st_nasty_text.filter(lambda x: len(x) > 0))
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_create_folder_roundtrip(self, db, name):
        db.messages.create_folder(name)
        folders = db.messages.get_all_folders()
        assert any(f["name"] == name for f in folders)

    @given(
        original=st_nasty_text.filter(lambda x: len(x) > 0),
        renamed=st_nasty_text.filter(lambda x: len(x) > 0),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=30,
    )
    def test_rename_folder(self, db, original, renamed):
        db.provider.execute("DELETE FROM lxmf_folders")
        cursor = db.messages.create_folder(original)
        folder_id = cursor.lastrowid
        db.messages.rename_folder(folder_id, renamed)
        folders = db.messages.get_all_folders()
        found = [f for f in folders if f["id"] == folder_id]
        assert len(found) == 1
        assert found[0]["name"] == renamed


# ===================================================================
# MessageHandler — search
# ===================================================================

class TestMessageHandlerFuzzing:

    @given(search=st_search_term)
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=80,
    )
    def test_search_messages_never_crashes(self, handler, search):
        results = handler.search_messages("local_hash", search)
        assert isinstance(results, list)

    @given(search=st_search_term)
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=60,
    )
    def test_get_conversations_search_never_crashes(self, handler, search):
        results = handler.get_conversations("local_hash", search=search)
        assert isinstance(results, list)

    @given(
        search=st_search_term,
        folder_id=st.one_of(st.none(), st.integers(min_value=0, max_value=1000)),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_get_conversations_with_filters(self, handler, search, folder_id):
        results = handler.get_conversations(
            "local_hash",
            search=search,
            folder_id=folder_id,
            filter_unread=True,
        )
        assert isinstance(results, list)

    @given(
        dest=st_hex_hash,
        after_id=st.one_of(st.none(), st.integers(min_value=-1000, max_value=1000)),
        before_id=st.one_of(st.none(), st.integers(min_value=-1000, max_value=1000)),
    )
    @settings(
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
        max_examples=40,
    )
    def test_get_conversation_messages_never_crashes(self, handler, dest, after_id, before_id):
        results = handler.get_conversation_messages(
            "local_hash", dest, after_id=after_id, before_id=before_id,
        )
        assert isinstance(results, list)


# ===================================================================
# _safe_href — URL sanitization
# ===================================================================

class TestSafeHrefFuzzing:

    @given(url=st.text(max_size=500))
    @settings(deadline=None, max_examples=200)
    def test_safe_href_never_returns_dangerous_protocol(self, url):
        result = _safe_href(url)
        lower = result.strip().lower()
        assert not lower.startswith("javascript:")
        assert not lower.startswith("data:")
        assert not lower.startswith("vbscript:")
        assert not lower.startswith("file:")

    @pytest.mark.parametrize(
        "url",
        [
            "javascript:alert(1)",
            "JAVASCRIPT:alert(1)",
            " javascript:alert(1) ",
            "jAvAsCrIpT:alert(document.cookie)",
            "data:text/html,<script>alert(1)</script>",
            "DATA:text/html;base64,PHNjcmlwdD4=",
            "vbscript:MsgBox('xss')",
            "file:///etc/passwd",
        ],
    )
    def test_safe_href_blocks_xss_vectors(self, url):
        assert _safe_href(url) == "#"

    @pytest.mark.parametrize(
        "url",
        [
            "https://example.com",
            "http://example.com",
            "/relative/path",
            "#anchor",
            "mailto:test@example.com",
        ],
    )
    def test_safe_href_allows_safe_urls(self, url):
        assert _safe_href(url) == url

    @given(
        url=st.from_regex(
            r"(javascript|data|vbscript|file):[A-Za-z0-9()=;,/+]+",
            fullmatch=True,
        )
    )
    @settings(deadline=None, max_examples=100)
    def test_safe_href_blocks_all_generated_dangerous_urls(self, url):
        assert _safe_href(url) == "#"

    @given(scheme=st.text(min_size=1, max_size=20).filter(lambda s: ":" not in s and "/" not in s))
    @settings(deadline=None, max_examples=60)
    def test_safe_href_blocks_unknown_schemes(self, scheme):
        """Any unknown scheme like 'foo:bar' should be blocked."""
        url = f"{scheme}:something"
        result = _safe_href(url)
        lower = url.strip().lower()
        if any(lower.startswith(p) for p in ("https://", "http://", "/", "#", "mailto:")):
            assert result == url
        else:
            assert result == "#"


# ===================================================================
# parse_bool_query_param
# ===================================================================

class TestParseBoolQueryParam:

    @given(value=st.text(max_size=100))
    @settings(deadline=None, max_examples=200)
    def test_never_crashes_and_returns_bool(self, value):
        result = parse_bool_query_param(value)
        assert isinstance(result, bool)

    @pytest.mark.parametrize("value,expected", [
        ("1", True),
        ("true", True),
        ("True", True),
        ("TRUE", True),
        ("yes", True),
        ("YES", True),
        ("on", True),
        ("ON", True),
        ("0", False),
        ("false", False),
        ("no", False),
        ("off", False),
        ("", False),
        ("maybe", False),
        (None, False),
    ])
    def test_known_values(self, value, expected):
        assert parse_bool_query_param(value) is expected


# ===================================================================
# message_fields_have_attachments
# ===================================================================

class TestMessageFieldsHaveAttachments:

    @given(data=st.text(max_size=500))
    @settings(deadline=None, max_examples=100)
    def test_never_crashes_on_arbitrary_text(self, data):
        result = message_fields_have_attachments(data)
        assert isinstance(result, bool)

    @given(
        obj=st.dictionaries(
            st.text(max_size=20),
            st.one_of(st.text(max_size=50), st.integers(), st.booleans(), st.none()),
            max_size=10,
        )
    )
    @settings(deadline=None, max_examples=80)
    def test_never_crashes_on_arbitrary_json_objects(self, obj):
        result = message_fields_have_attachments(json.dumps(obj))
        assert isinstance(result, bool)

    @pytest.mark.parametrize("fields_json,expected", [
        (None, False),
        ("", False),
        ("{}", False),
        ("not json", False),
        ('{"image": "base64data"}', True),
        ('{"audio": "base64data"}', True),
        ('{"file_attachments": [{"name": "f.txt"}]}', True),
        ('{"file_attachments": []}', False),
        ('{"other_field": "value"}', False),
    ])
    def test_known_cases(self, fields_json, expected):
        assert message_fields_have_attachments(fields_json) is expected

    @given(
        data=st.recursive(
            st.one_of(st.text(max_size=20), st.integers(), st.booleans(), st.none()),
            lambda children: st.lists(children, max_size=5)
            | st.dictionaries(st.text(max_size=10), children, max_size=5),
            max_leaves=30,
        )
    )
    @settings(deadline=None, max_examples=60)
    def test_deeply_nested_json_never_crashes(self, data):
        result = message_fields_have_attachments(json.dumps(data))
        assert isinstance(result, bool)
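Suites like the one above usually pair with a deterministic Hypothesis profile so CI failures reproduce locally. A sketch of such a conftest.py — the profile names are illustrative, and this file is not part of the commit:

# conftest.py (hypothetical) — reproducible fuzzing in CI, deeper runs on demand.
from hypothesis import HealthCheck, settings

settings.register_profile(
    "ci",
    max_examples=50,
    deadline=None,
    derandomize=True,  # fixed example sequence, so CI failures replay locally
    suppress_health_check=[HealthCheck.function_scoped_fixture],
)
settings.register_profile("deep", max_examples=1000, deadline=None)
settings.load_profile("ci")  # override with: pytest --hypothesis-profile=deep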
@@ -0,0 +1,598 @@
"""Performance regression tests for the critical hot paths.

Focus areas (user priority):
- NomadNet browser: load announces, search announces, favourites
- Messages: load conversations, search messages, load conversation messages,
  upsert messages (drafts)

Metrics collected:
- ops/sec throughput
- p50 / p95 / p99 latency
- Concurrent writer contention
- LIKE-search scaling

All tests have hard assertions so regressions fail CI.
"""

import os
import secrets
import shutil
import statistics
import tempfile
import threading
import time
import unittest
from unittest.mock import MagicMock

from meshchatx.src.backend.announce_manager import AnnounceManager
from meshchatx.src.backend.database import Database
from meshchatx.src.backend.message_handler import MessageHandler


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def percentile(data, pct):
    """Return the pct-th percentile of sorted data."""
    if not data:
        return 0
    s = sorted(data)
    k = (len(s) - 1) * (pct / 100)
    f = int(k)
    c = f + 1
    if c >= len(s):
        return s[f]
    return s[f] + (k - f) * (s[c] - s[f])


def timed_call(fn, *args, **kwargs):
    """Call fn and return (result, duration_ms)."""
    t0 = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, (time.perf_counter() - t0) * 1000


def latency_report(name, durations_ms):
    """Print and return latency stats."""
    p50 = percentile(durations_ms, 50)
    p95 = percentile(durations_ms, 95)
    p99 = percentile(durations_ms, 99)
    avg = statistics.mean(durations_ms)
    ops = 1000 / avg if avg > 0 else float("inf")
    print(
        f"  {name}: avg={avg:.2f}ms p50={p50:.2f}ms p95={p95:.2f}ms "
        f"p99={p99:.2f}ms ops/s={ops:.0f}"
    )
    return {"avg": avg, "p50": p50, "p95": p95, "p99": p99, "ops": ops}


def make_message(peer_hash, i, content_size=100):
    return {
        "hash": secrets.token_hex(16),
        "source_hash": peer_hash,
        "destination_hash": "local_hash_0" * 2,
        "peer_hash": peer_hash,
        "state": "delivered",
        "progress": 1.0,
        "is_incoming": i % 2,
        "method": "direct",
        "delivery_attempts": 1,
        "next_delivery_attempt_at": None,
        "title": f"Message title {i} " + secrets.token_hex(8),
        "content": f"Content body {i} " + "x" * content_size,
        "fields": "{}",
        "timestamp": time.time() - i,
        "rssi": -50,
        "snr": 5.0,
        "quality": 3,
        "is_spam": 0,
        "reply_to_hash": None,
    }


def make_announce(i):
    return {
        "destination_hash": secrets.token_hex(16),
        "aspect": "lxmf.delivery" if i % 3 != 0 else "lxst.telephony",
        "identity_hash": secrets.token_hex(16),
        "identity_public_key": "pubkey_" + secrets.token_hex(8),
        "app_data": "appdata_" + secrets.token_hex(16),
        "rssi": -50 + (i % 30),
        "snr": 5.0,
        "quality": i % 10,
    }


# ---------------------------------------------------------------------------
# Test class
# ---------------------------------------------------------------------------

class TestPerformanceHotPaths(unittest.TestCase):

    NUM_MESSAGES = 10_000
    NUM_PEERS = 200
    NUM_ANNOUNCES = 5_000
    NUM_FAVOURITES = 100
    NUM_CONTACTS = 50

    @classmethod
    def setUpClass(cls):
        cls.test_dir = tempfile.mkdtemp()
        cls.db_path = os.path.join(cls.test_dir, "perf_hotpaths.db")
        cls.db = Database(cls.db_path)
        cls.db.initialize()
        cls.handler = MessageHandler(cls.db)
        cls.announce_mgr = AnnounceManager(cls.db)

        cls._seed_data()

    @classmethod
    def tearDownClass(cls):
        cls.db.close_all()
        shutil.rmtree(cls.test_dir, ignore_errors=True)

    @classmethod
    def _seed_data(cls):
        print(f"\n--- Seeding test data ---")

        # Peers
        cls.peer_hashes = [secrets.token_hex(16) for _ in range(cls.NUM_PEERS)]
        cls.heavy_peer = cls.peer_hashes[0]

        # Messages: distribute across peers, heavy_peer gets 2000
        print(f"  Seeding {cls.NUM_MESSAGES} messages across {cls.NUM_PEERS} peers...")
        t0 = time.perf_counter()
        with cls.db.provider:
            for i in range(cls.NUM_MESSAGES):
                if i < 2000:
                    peer = cls.heavy_peer
                else:
                    peer = cls.peer_hashes[i % cls.NUM_PEERS]
                cls.db.messages.upsert_lxmf_message(make_message(peer, i))
        print(f"  Done in {(time.perf_counter() - t0) * 1000:.0f}ms")

        # Announces
        print(f"  Seeding {cls.NUM_ANNOUNCES} announces...")
        cls.announce_hashes = []
        t0 = time.perf_counter()
        with cls.db.provider:
            for i in range(cls.NUM_ANNOUNCES):
                data = make_announce(i)
                cls.announce_hashes.append(data["destination_hash"])
                cls.db.announces.upsert_announce(data)
        print(f"  Done in {(time.perf_counter() - t0) * 1000:.0f}ms")

        # Favourites
        print(f"  Seeding {cls.NUM_FAVOURITES} favourites...")
        with cls.db.provider:
            for i in range(cls.NUM_FAVOURITES):
                cls.db.announces.upsert_favourite(
                    cls.announce_hashes[i],
                    f"Fav Node {i}",
                    "lxmf.delivery",
                )

        # Contacts (for JOIN benchmarks)
        print(f"  Seeding {cls.NUM_CONTACTS} contacts...")
        with cls.db.provider:
            for i in range(cls.NUM_CONTACTS):
                cls.db.contacts.add_contact(
                    name=f"Contact {i}",
                    remote_identity_hash=cls.peer_hashes[i % cls.NUM_PEERS],
                    lxmf_address=cls.peer_hashes[(i + 1) % cls.NUM_PEERS],
                )

        print("--- Seeding complete ---\n")

    # ===================================================================
    # ANNOUNCES — load, search, count
    # ===================================================================

    def test_announce_load_filtered_latency(self):
        """Load announces filtered by aspect with pagination — the NomadNet browser default view."""
        print("\n[Announce] Filtered load (aspect + pagination):")
        durations = []
        offsets = [0, 100, 500, 1000, 2000]
        for offset in offsets:
            _, ms = timed_call(
                self.announce_mgr.get_filtered_announces,
                aspect="lxmf.delivery",
                limit=50,
                offset=offset,
            )
            durations.append(ms)

        stats = latency_report("filtered_load", durations)
        self.assertLess(stats["p99"], 100, "Announce filtered load p99 > 100ms")

    def test_announce_search_latency(self):
        """Search announces by destination/identity hash substring."""
        print("\n[Announce] LIKE search:")
        search_terms = [
            secrets.token_hex(4),
            "abc",
            self.announce_hashes[0][:8],
            self.announce_hashes[2500][:10],
            "nonexistent_term_xyz",
        ]
        durations = []
        for term in search_terms:
            _, ms = timed_call(
                self.announce_mgr.get_filtered_announces,
                aspect="lxmf.delivery",
                query=term,
                limit=50,
                offset=0,
            )
            durations.append(ms)

        stats = latency_report("search", durations)
        self.assertLess(stats["p95"], 150, "Announce search p95 > 150ms")

    def test_announce_search_with_blocked(self):
        """Search with a block-list — simulates real NomadNet browser filtering."""
        print("\n[Announce] Search with blocked list:")
        blocked = [secrets.token_hex(16) for _ in range(50)]
        durations = []
        for _ in range(20):
            _, ms = timed_call(
                self.announce_mgr.get_filtered_announces,
                aspect="lxmf.delivery",
                query="abc",
                blocked_identity_hashes=blocked,
                limit=50,
                offset=0,
            )
            durations.append(ms)

        stats = latency_report("search+blocked", durations)
        self.assertLess(stats["p95"], 200, "Announce search+blocked p95 > 200ms")

    def test_announce_count_latency(self):
        """Count announces (used for pagination total)."""
        print("\n[Announce] Count:")
        durations = []
        for _ in range(30):
            _, ms = timed_call(
                self.announce_mgr.get_filtered_announces_count,
                aspect="lxmf.delivery",
            )
            durations.append(ms)

        stats = latency_report("count", durations)
        self.assertLess(stats["p95"], 100, "Announce count p95 > 100ms")

    # ===================================================================
    # FAVOURITES
    # ===================================================================

    def test_favourites_load_latency(self):
        """Load all favourites — typically displayed in sidebar."""
        print("\n[Favourites] Load all:")
        durations = []
        for _ in range(50):
            _, ms = timed_call(self.db.announces.get_favourites, "lxmf.delivery")
            durations.append(ms)

        stats = latency_report("load_favs", durations)
        self.assertLess(stats["p95"], 20, "Favourites load p95 > 20ms")

    def test_favourite_upsert_throughput(self):
        """Measure upsert throughput for favourites."""
        print("\n[Favourites] Upsert throughput:")
        durations = []
        for i in range(100):
            dest = secrets.token_hex(16)
            _, ms = timed_call(
                self.db.announces.upsert_favourite,
                dest, f"Bench Fav {i}", "lxmf.delivery",
            )
            durations.append(ms)

        stats = latency_report("upsert_fav", durations)
        self.assertGreater(stats["ops"], 500, "Favourite upsert < 500 ops/s")

    # ===================================================================
    # CONVERSATIONS — load, search
    # ===================================================================

    def test_conversations_load_latency(self):
        """Load conversation list — the main messages sidebar query."""
        print("\n[Conversations] Load list (with JOINs):")
        durations = []
        for _ in range(20):
            _, ms = timed_call(
                self.handler.get_conversations,
                "local_hash",
                limit=50,
                offset=0,
            )
            durations.append(ms)

        stats = latency_report("load_conversations", durations)
        self.assertLess(stats["p95"], 200, "Conversation list p95 > 200ms")

    def test_conversations_search_latency(self):
        """Search conversations — LIKE across titles, content, peer hashes."""
        print("\n[Conversations] Search:")
        terms = ["Message title 5", "Content body", "abc", "zzz_nope", self.heavy_peer[:8]]
        durations = []
        for term in terms:
            _, ms = timed_call(
                self.handler.get_conversations,
                "local_hash",
                search=term,
                limit=50,
            )
            durations.append(ms)

        stats = latency_report("search_conversations", durations)
        self.assertLess(stats["p95"], 500, "Conversation search p95 > 500ms")

    def test_conversations_load_paginated(self):
        """Paginate through conversation list at various offsets."""
        print("\n[Conversations] Paginated load:")
        durations = []
        for offset in [0, 20, 50, 100, 150]:
            _, ms = timed_call(
                self.handler.get_conversations,
                "local_hash",
                limit=20,
                offset=offset,
            )
            durations.append(ms)

        stats = latency_report("paginated_conversations", durations)
        self.assertLess(stats["p95"], 300, "Paginated conversations p95 > 300ms")

    # ===================================================================
    # MESSAGES — load, search, upsert (drafts)
    # ===================================================================

    def test_message_load_latency(self):
        """Load messages for a single conversation (heavy peer with 2000 msgs)."""
        print("\n[Messages] Load conversation messages:")
        durations = []
        offsets = [0, 100, 500, 1000, 1900]
        for offset in offsets:
            result, ms = timed_call(
                self.handler.get_conversation_messages,
                "local_hash",
                self.heavy_peer,
                limit=50,
                offset=offset,
            )
            durations.append(ms)
            self.assertEqual(len(result), 50)

        stats = latency_report("load_messages", durations)
        self.assertLess(stats["p99"], 50, "Message load p99 > 50ms")

    def test_message_search_latency(self):
        """Search messages across all conversations — the global search."""
        print("\n[Messages] Global search:")
        terms = [
            "Message title 100",
            "Content body 5000",
            secrets.token_hex(4),
            "nonexistent_xyz_123",
        ]
        durations = []
        for term in terms:
            _, ms = timed_call(
                self.handler.search_messages,
                "local_hash",
                term,
            )
            durations.append(ms)

        stats = latency_report("search_messages", durations)
        self.assertLess(stats["p95"], 300, "Message search p95 > 300ms")

    def test_message_upsert_throughput(self):
        """Measure message upsert throughput — simulates saving drafts rapidly."""
        print("\n[Messages] Upsert throughput (draft saves):")
        durations = []
        peer = secrets.token_hex(16)
        for i in range(200):
            msg = make_message(peer, i + 100000)
            _, ms = timed_call(self.db.messages.upsert_lxmf_message, msg)
            durations.append(ms)

        stats = latency_report("upsert_message", durations)
        self.assertGreater(stats["ops"], 300, "Message upsert < 300 ops/s")
        self.assertLess(stats["p95"], 10, "Message upsert p95 > 10ms")

    def test_message_upsert_update_throughput(self):
        """Measure message UPDATE throughput — re-saving existing messages (state changes)."""
        print("\n[Messages] Update existing messages:")
        peer = secrets.token_hex(16)
        msgs = []
        for i in range(100):
            msg = make_message(peer, i + 200000)
            self.db.messages.upsert_lxmf_message(msg)
            msgs.append(msg)

        durations = []
        for msg in msgs:
            msg["state"] = "failed"
            msg["content"] = "Updated content " + secrets.token_hex(16)
            _, ms = timed_call(self.db.messages.upsert_lxmf_message, msg)
            durations.append(ms)

        stats = latency_report("update_message", durations)
        self.assertGreater(stats["ops"], 300, "Message update < 300 ops/s")

    # ===================================================================
    # CONCURRENT WRITERS — contention stress
    # ===================================================================

    def test_concurrent_message_writers(self):
        """Multiple threads inserting messages simultaneously."""
        print("\n[Concurrency] Message writers:")
        num_threads = 8
        msgs_per_thread = 100
        errors = []
        all_durations = []
        lock = threading.Lock()

        def writer(thread_id):
            thread_durations = []
            peer = secrets.token_hex(16)
            for i in range(msgs_per_thread):
                msg = make_message(peer, thread_id * 10000 + i)
                try:
                    _, ms = timed_call(self.db.messages.upsert_lxmf_message, msg)
                    thread_durations.append(ms)
                except Exception as e:
                    errors.append(str(e))
            with lock:
                all_durations.extend(thread_durations)

        threads = [threading.Thread(target=writer, args=(t,)) for t in range(num_threads)]
        t0 = time.perf_counter()
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        wall_ms = (time.perf_counter() - t0) * 1000

        total_ops = num_threads * msgs_per_thread
        throughput = total_ops / (wall_ms / 1000)
        print(f"  Wall time: {wall_ms:.0f}ms for {total_ops} inserts ({throughput:.0f} ops/s)")
        stats = latency_report("concurrent_write", all_durations)

        self.assertEqual(len(errors), 0, f"Writer errors: {errors[:5]}")
        self.assertGreater(throughput, 100, "Concurrent write throughput < 100 ops/s")

    def test_concurrent_announce_writers(self):
        """Multiple threads upserting announces simultaneously."""
        print("\n[Concurrency] Announce writers:")
        num_threads = 6
        announces_per_thread = 100
        errors = []
        all_durations = []
        lock = threading.Lock()

        def writer(thread_id):
            thread_durations = []
            for i in range(announces_per_thread):
                data = make_announce(thread_id * 10000 + i)
                try:
                    _, ms = timed_call(self.db.announces.upsert_announce, data)
                    thread_durations.append(ms)
                except Exception as e:
                    errors.append(str(e))
            with lock:
                all_durations.extend(thread_durations)

        threads = [threading.Thread(target=writer, args=(t,)) for t in range(num_threads)]
        t0 = time.perf_counter()
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        wall_ms = (time.perf_counter() - t0) * 1000

        total_ops = num_threads * announces_per_thread
        throughput = total_ops / (wall_ms / 1000)
        print(f"  Wall time: {wall_ms:.0f}ms for {total_ops} upserts ({throughput:.0f} ops/s)")
        stats = latency_report("concurrent_announce_write", all_durations)

        self.assertEqual(len(errors), 0, f"Writer errors: {errors[:5]}")
        self.assertGreater(throughput, 100, "Concurrent announce write < 100 ops/s")

    def test_concurrent_read_write_contention(self):
        """Writers inserting while readers query — simulates real app usage."""
        print("\n[Contention] Mixed read/write:")
        num_writers = 4
        num_readers = 4
        ops_per_thread = 50
        write_errors = []
        read_errors = []
        write_durations = []
        read_durations = []
        lock = threading.Lock()

        def writer(thread_id):
            local_durs = []
            peer = secrets.token_hex(16)
            for i in range(ops_per_thread):
                msg = make_message(peer, thread_id * 10000 + i)
                try:
                    _, ms = timed_call(self.db.messages.upsert_lxmf_message, msg)
                    local_durs.append(ms)
                except Exception as e:
                    write_errors.append(str(e))
            with lock:
                write_durations.extend(local_durs)

        def reader(_thread_id):
            local_durs = []
            for _ in range(ops_per_thread):
                try:
                    _, ms = timed_call(
                        self.handler.get_conversations,
                        "local_hash",
                        limit=20,
                    )
                    local_durs.append(ms)
                except Exception as e:
                    read_errors.append(str(e))
            with lock:
                read_durations.extend(local_durs)

        writers = [threading.Thread(target=writer, args=(t,)) for t in range(num_writers)]
        readers = [threading.Thread(target=reader, args=(t,)) for t in range(num_readers)]

        t0 = time.perf_counter()
        for t in writers + readers:
            t.start()
        for t in writers + readers:
            t.join()
        wall_ms = (time.perf_counter() - t0) * 1000

        print(f"  Wall time: {wall_ms:.0f}ms")
        latency_report("contention_writes", write_durations)
        latency_report("contention_reads", read_durations)

        self.assertEqual(len(write_errors), 0, f"Write errors: {write_errors[:5]}")
        self.assertEqual(len(read_errors), 0, f"Read errors: {read_errors[:5]}")

    # ===================================================================
    # LIKE SEARCH SCALING — how search degrades with data size
    # ===================================================================

    def test_like_search_scaling(self):
        """Measure how LIKE search scales across different table sizes.

        This catches missing FTS indexes or query plan regressions."""
        print("\n[Scaling] LIKE search across data sizes:")

        # Message search on the existing 10k dataset
        _, ms_msg = timed_call(
            self.handler.search_messages,
            "local_hash",
            "Content body",
        )
        print(f"  Message LIKE search ({self.NUM_MESSAGES} rows): {ms_msg:.2f}ms")
        self.assertLess(ms_msg, 500, "Message LIKE search > 500ms on 10k rows")

        # Announce search on the existing 5k dataset
        _, ms_ann = timed_call(
            self.announce_mgr.get_filtered_announces,
            aspect="lxmf.delivery",
            query="abc",
            limit=50,
        )
        print(f"  Announce LIKE search ({self.NUM_ANNOUNCES} rows): {ms_ann:.2f}ms")
        self.assertLess(ms_ann, 200, "Announce LIKE search > 200ms on 5k rows")

        # Contacts search
        _, ms_con = timed_call(self.db.contacts.get_contacts, search="Contact")
        print(f"  Contacts LIKE search ({self.NUM_CONTACTS} rows): {ms_con:.2f}ms")
        self.assertLess(ms_con, 50, "Contacts LIKE search > 50ms")


if __name__ == "__main__":
    unittest.main()
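As an aside, the hand-rolled `percentile` helper above uses the same linear interpolation as the standard library's inclusive quantile method, which makes for an easy cross-check; a small sketch (the sample values are made up):

import statistics

samples = [12.0, 15.5, 9.8, 30.2, 11.1, 14.0, 50.7, 13.3]

# method="inclusive" interpolates at (len - 1) * pct/100, matching percentile().
cuts = statistics.quantiles(samples, n=100, method="inclusive")
p50, p95, p99 = cuts[49], cuts[94], cuts[98]

assert abs(p50 - statistics.median(samples)) < 1e-9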
@@ -0,0 +1,392 @@
|
||||
"""Tests confirming SQL-injection fixes in the raw-SQL database layer.
|
||||
|
||||
Covers:
|
||||
- ATTACH DATABASE path escaping (single-quote doubling) in LegacyMigrator
|
||||
- Column-name identifier filtering during legacy migration
|
||||
- _validate_identifier / _ensure_column rejection in DatabaseSchema
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
from hypothesis import HealthCheck, given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from meshchatx.src.backend.database.legacy_migrator import LegacyMigrator
|
||||
from meshchatx.src.backend.database.provider import DatabaseProvider
|
||||
from meshchatx.src.backend.database.schema import DatabaseSchema, _validate_identifier
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def db_env(tmp_path):
|
||||
"""Provide an initialized DatabaseProvider + Schema in a temp directory."""
|
||||
db_path = str(tmp_path / "current.db")
|
||||
provider = DatabaseProvider(db_path)
|
||||
schema = DatabaseSchema(provider)
|
||||
schema.initialize()
|
||||
yield provider, schema, tmp_path
|
||||
provider.close()
|
||||
|
||||
|
||||
def _make_legacy_db(legacy_dir, identity_hash, tables_sql):
|
||||
"""Create a legacy database with the given CREATE TABLE + INSERT statements."""
|
||||
identity_dir = os.path.join(legacy_dir, "identities", identity_hash)
|
||||
os.makedirs(identity_dir, exist_ok=True)
|
||||
db_path = os.path.join(identity_dir, "database.db")
|
||||
conn = sqlite3.connect(db_path)
|
||||
for sql in tables_sql:
|
||||
conn.execute(sql)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return db_path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. ATTACH DATABASE — single-quote escaping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAttachDatabasePathEscaping:
|
||||
|
||||
def test_path_without_quotes_migrates_normally(self, db_env):
|
||||
provider, _schema, tmp_path = db_env
|
||||
legacy_dir = str(tmp_path / "legacy_normal")
|
||||
identity_hash = "aabbccdd"
|
||||
_make_legacy_db(
|
||||
legacy_dir,
|
||||
identity_hash,
|
||||
[
|
||||
"CREATE TABLE config (key TEXT UNIQUE, value TEXT)",
|
||||
"INSERT INTO config (key, value) VALUES ('k1', 'v1')",
|
||||
],
|
||||
)
|
||||
|
||||
migrator = LegacyMigrator(provider, legacy_dir, identity_hash)
|
||||
assert migrator.migrate() is True
|
||||
|
||||
row = provider.fetchone("SELECT value FROM config WHERE key = 'k1'")
|
||||
assert row is not None
|
||||
assert row["value"] == "v1"
|
||||
|
||||
def test_path_with_single_quote_does_not_crash(self, db_env):
|
||||
"""A path containing a single quote must not cause SQL injection or crash."""
|
||||
provider, _schema, tmp_path = db_env
|
||||
|
||||
quoted_dir = tmp_path / "it's_a_test"
|
||||
quoted_dir.mkdir(parents=True, exist_ok=True)
|
||||
legacy_dir = str(quoted_dir)
|
||||
identity_hash = "aabbccdd"
|
||||
_make_legacy_db(
|
||||
legacy_dir,
|
||||
identity_hash,
|
||||
[
|
||||
"CREATE TABLE config (key TEXT UNIQUE, value TEXT)",
|
||||
"INSERT INTO config (key, value) VALUES ('q1', 'quoted_val')",
|
||||
],
|
||||
)
|
||||
|
||||
migrator = LegacyMigrator(provider, legacy_dir, identity_hash)
|
||||
result = migrator.migrate()
|
||||
assert result is True
|
||||
|
||||
row = provider.fetchone("SELECT value FROM config WHERE key = 'q1'")
|
||||
assert row is not None
|
||||
assert row["value"] == "quoted_val"
|
||||
|
||||
def test_path_with_multiple_quotes(self, db_env):
|
||||
"""Multiple single quotes in the path are all escaped."""
|
||||
provider, _schema, tmp_path = db_env
|
||||
|
||||
weird_dir = tmp_path / "a'b'c"
|
||||
weird_dir.mkdir(parents=True, exist_ok=True)
|
||||
legacy_dir = str(weird_dir)
|
||||
identity_hash = "11223344"
|
||||
_make_legacy_db(
|
||||
legacy_dir,
|
||||
identity_hash,
|
||||
[
|
||||
"CREATE TABLE config (key TEXT UNIQUE, value TEXT)",
|
||||
"INSERT INTO config (key, value) VALUES ('mq', 'multi_quote')",
|
||||
],
|
||||
)
|
||||
|
||||
migrator = LegacyMigrator(provider, legacy_dir, identity_hash)
|
||||
assert migrator.migrate() is True
|
||||
|
||||
row = provider.fetchone("SELECT value FROM config WHERE key = 'mq'")
|
||||
assert row is not None
|
||||
assert row["value"] == "multi_quote"
|
||||
|
||||
def test_path_with_sql_injection_attempt(self, db_env):
|
||||
"""A path crafted to look like SQL injection is safely escaped."""
|
||||
provider, _schema, tmp_path = db_env
|
||||
|
||||
evil_dir = tmp_path / "'; DROP TABLE config; --"
|
||||
evil_dir.mkdir(parents=True, exist_ok=True)
|
||||
legacy_dir = str(evil_dir)
|
||||
identity_hash = "deadbeef"
|
||||
_make_legacy_db(
|
||||
legacy_dir,
|
||||
identity_hash,
|
||||
[
|
||||
"CREATE TABLE config (key TEXT UNIQUE, value TEXT)",
|
||||
"INSERT INTO config (key, value) VALUES ('evil', 'nope')",
|
||||
],
|
||||
)
|
||||
|
||||
migrator = LegacyMigrator(provider, legacy_dir, identity_hash)
|
||||
migrator.migrate()
|
||||
|
||||
tables = provider.fetchall(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='config'"
|
||||
)
|
||||
assert len(tables) > 0, "config table must still exist after injection attempt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Legacy migrator — malicious column names filtered out
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLegacyColumnFiltering:

    def test_normal_columns_migrate(self, db_env):
        """Standard column names pass through the identifier filter."""
        provider, _schema, tmp_path = db_env
        legacy_dir = str(tmp_path / "legacy_cols")
        identity_hash = "aabb0011"
        _make_legacy_db(
            legacy_dir,
            identity_hash,
            [
                "CREATE TABLE config (key TEXT UNIQUE, value TEXT)",
                "INSERT INTO config (key, value) VALUES ('c1', 'ok')",
            ],
        )

        migrator = LegacyMigrator(provider, legacy_dir, identity_hash)
        assert migrator.migrate() is True

        row = provider.fetchone("SELECT value FROM config WHERE key = 'c1'")
        assert row is not None

    def test_malicious_column_name_is_skipped(self, db_env):
        """A column with SQL metacharacters in its name must be silently skipped."""
        provider, _schema, tmp_path = db_env
        legacy_dir = str(tmp_path / "legacy_evil_col")
        identity_hash = "cc00dd00"

        identity_dir = os.path.join(legacy_dir, "identities", identity_hash)
        os.makedirs(identity_dir, exist_ok=True)
        db_path = os.path.join(identity_dir, "database.db")
        conn = sqlite3.connect(db_path)
        conn.execute(
            'CREATE TABLE config (key TEXT UNIQUE, value TEXT, "key; DROP TABLE config" TEXT)'
        )
        conn.execute("INSERT INTO config (key, value) VALUES ('safe', 'data')")
        conn.commit()
        conn.close()

        migrator = LegacyMigrator(provider, legacy_dir, identity_hash)
        assert migrator.migrate() is True

        row = provider.fetchone("SELECT value FROM config WHERE key = 'safe'")
        assert row is not None
        assert row["value"] == "data"

    def test_column_with_parentheses_is_skipped(self, db_env):
        """Columns with () in the name are rejected by the identifier regex."""
        provider, _schema, tmp_path = db_env
        legacy_dir = str(tmp_path / "legacy_parens_col")
        identity_hash = "ee00ff00"

        identity_dir = os.path.join(legacy_dir, "identities", identity_hash)
        os.makedirs(identity_dir, exist_ok=True)
        db_path = os.path.join(identity_dir, "database.db")
        conn = sqlite3.connect(db_path)
        conn.execute(
            'CREATE TABLE config (key TEXT UNIQUE, value TEXT, "evil()" TEXT)'
        )
        conn.execute("INSERT INTO config (key, value) VALUES ('p1', 'parens')")
        conn.commit()
        conn.close()

        migrator = LegacyMigrator(provider, legacy_dir, identity_hash)
        assert migrator.migrate() is True

        row = provider.fetchone("SELECT value FROM config WHERE key = 'p1'")
        assert row is not None
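

# Hedged follow-up sketch: a stricter variant of the tests above could inspect
# the migrated schema directly and assert the hostile column never arrived.
# `_column_names` is illustrative only; it assumes provider rows are mappable
# by column name (as row["value"] is used elsewhere in this suite) and relies
# on the documented `name` column of PRAGMA table_info output.
def _column_names(provider, table_name):
    """Return the column names that PRAGMA table_info reports for a table."""
    _validate_identifier(table_name, "table name")
    rows = provider.fetchall(f"PRAGMA table_info({table_name})")  # noqa: S608
    return [row["name"] for row in rows]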


# ---------------------------------------------------------------------------
# 3. _validate_identifier — unit tests
# ---------------------------------------------------------------------------

class TestValidateIdentifier:

    @pytest.mark.parametrize(
        "name",
        [
            "config",
            "lxmf_messages",
            "A",
            "_private",
            "Column123",
            "a_b_c_d",
        ],
    )
    def test_valid_identifiers_pass(self, name):
        assert _validate_identifier(name) == name

    @pytest.mark.parametrize(
        "name",
        [
            "",
            "123abc",
            "table name",
            "col;drop",
            "a'b",
            'a"b',
            "col()",
            "x--y",
            "a,b",
            "hello\nworld",
            "tab\there",
            "col/**/name",
        ],
    )
    def test_invalid_identifiers_raise(self, name):
        with pytest.raises(ValueError, match="Invalid SQL"):
            _validate_identifier(name)
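

# Hedged usage sketch: parameter placeholders only cover SQL values, so code
# that must interpolate an identifier is expected to validate it first.
# `fetch_all_rows` is a hypothetical helper, not part of the codebase, shown
# only to illustrate the intended call pattern.
def fetch_all_rows(provider, table_name):
    """Return every row of `table_name`, refusing unsafe identifiers."""
    _validate_identifier(table_name, "table name")
    return provider.fetchall(f"SELECT * FROM {table_name}")  # noqa: S608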


# ---------------------------------------------------------------------------
# 4. _ensure_column — rejects injection via table/column names
# ---------------------------------------------------------------------------

class TestEnsureColumnInjection:

    def test_ensure_column_rejects_malicious_table_name(self, db_env):
        _provider, schema, _tmp_path = db_env
        with pytest.raises(ValueError, match="Invalid SQL table name"):
            schema._ensure_column("config; DROP TABLE config", "new_col", "TEXT")

    def test_ensure_column_rejects_malicious_column_name(self, db_env):
        _provider, schema, _tmp_path = db_env
        with pytest.raises(ValueError, match="Invalid SQL column name"):
            schema._ensure_column("config", "col; DROP TABLE config", "TEXT")

    def test_ensure_column_works_for_valid_names(self, db_env):
        _provider, schema, _tmp_path = db_env
        result = schema._ensure_column("config", "test_new_col", "TEXT")
        assert result is True

    def test_ensure_column_idempotent(self, db_env):
        _provider, schema, _tmp_path = db_env
        schema._ensure_column("config", "idempotent_col", "TEXT")
        result = schema._ensure_column("config", "idempotent_col", "TEXT")
        assert result is True


# ---------------------------------------------------------------------------
# 5. Property-based tests — identifier regex
# ---------------------------------------------------------------------------
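
# Local copy of the production identifier pattern so the properties below can
# cross-check acceptance independently (assumed to stay in sync with the
# module under test).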
_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


@given(name=st.text(min_size=1, max_size=80))
@settings(deadline=None)
def test_validate_identifier_never_allows_sql_metacharacters(name):
    """No string accepted by _validate_identifier contains SQL metacharacters."""
    try:
        _validate_identifier(name)
    except ValueError:
        # Rejected input: the property only constrains accepted identifiers.
        return

    for ch in ";'\"() -/\\\n\r\t,":
        assert ch not in name, f"accepted identifier contains {ch!r}"
    assert _IDENTIFIER_RE.match(name)


@given(name=st.from_regex(r"[A-Za-z_][A-Za-z0-9_]{0,30}", fullmatch=True))
@settings(deadline=None)
def test_validate_identifier_accepts_all_valid_identifiers(name):
    """Every string matching the identifier pattern is accepted."""
    assert _validate_identifier(name) == name


@given(
    name=st.text(
        alphabet=st.sampled_from(list(";'\"()- \t\n\r,/*")),
        min_size=1,
        max_size=30,
    )
)
@settings(deadline=None)
def test_validate_identifier_rejects_pure_metacharacter_strings(name):
    """Strings composed entirely of SQL metacharacters are always rejected."""
    with pytest.raises(ValueError):
        _validate_identifier(name)


# ---------------------------------------------------------------------------
# 6. ATTACH path escaping — property-based
# ---------------------------------------------------------------------------

@given(
    path_segment=st.text(
        alphabet=st.characters(
            whitelist_categories=("L", "N", "P", "S", "Z"),
        ),
        min_size=1,
        max_size=60,
    )
)
@settings(deadline=None, suppress_health_check=[HealthCheck.too_slow])
def test_attach_path_escaping_never_breaks_sql(path_segment):
    """The quote-doubling escape produces a string literal that SQLite can
    parse without breaking out of it, regardless of the path content."""
    safe = path_segment.replace("'", "''")
    sql = f"ATTACH DATABASE '{safe}' AS test_alias"

    # Sanity: the statement starts with the expected prefix. (A plain
    # substring count could be fooled by a path that itself contains
    # "ATTACH DATABASE '".)
    assert sql.startswith("ATTACH DATABASE '")

    # Walk the literal the way SQLite's tokenizer would: a doubled quote stays
    # inside the literal, a lone quote terminates it.
    after_open = sql.split("ATTACH DATABASE '", 1)[1]
    in_literal = True
    remainder = ""
    i = 0
    while i < len(after_open):
        if after_open[i] == "'":
            if i + 1 < len(after_open) and after_open[i + 1] == "'":
                i += 2
                continue
            in_literal = False
            remainder = after_open[i + 1 :]
            break
        i += 1

    # The closing quote before " AS test_alias" is never doubled, so the
    # literal must terminate, and nothing but the alias clause may follow.
    assert not in_literal, "string literal never terminated"
    assert remainder.strip() == "AS test_alias", (
        f"Unexpected SQL after literal end: {remainder!r}"
    )
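

# Hedged companion sketch: rather than hand-parsing the literal, feed the
# escaped string through a real SQLite engine and confirm it round-trips
# unchanged. Illustration of an alternative oracle only; it reuses the
# alphabet above and stdlib sqlite3.
@given(
    path_segment=st.text(
        alphabet=st.characters(
            whitelist_categories=("L", "N", "P", "S", "Z"),
        ),
        min_size=1,
        max_size=60,
    )
)
@settings(deadline=None, suppress_health_check=[HealthCheck.too_slow])
def test_attach_escaping_roundtrips_through_sqlite_sketch(path_segment):
    safe = path_segment.replace("'", "''")
    conn = sqlite3.connect(":memory:")
    try:
        (value,) = conn.execute(f"SELECT '{safe}'").fetchone()  # noqa: S608
    finally:
        conn.close()
    assert value == path_segment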