Improve database performance and SQL handling

- Introduced SQLite pragma tuning in the Database initialization for improved performance.
- Wrapped multiple database operations in transactions to optimize batch processing in MessageDAO.
- Updated DatabaseSchema to version 39, adding new indexes for better query performance.
- Improved test coverage for batch operations and SQL injection scenarios in the DAO layer.
This commit is contained in:
Sudo-Ivan
2026-03-06 03:26:36 -06:00
parent 9d7ae2017b
commit bc8969ab16
6 changed files with 354 additions and 119 deletions
+5 -1
View File
@@ -43,6 +43,7 @@ class Database:
self.debug_logs = DebugLogsDAO(self.provider)
def initialize(self):
self._tune_sqlite_pragmas()
self.schema.initialize()
def migrate_from_legacy(self, reticulum_config_dir, identity_hash_hex):
@@ -60,9 +61,12 @@ class Database:
def _tune_sqlite_pragmas(self):
try:
self.execute_sql("PRAGMA journal_mode=WAL")
self.execute_sql("PRAGMA synchronous=NORMAL")
self.execute_sql("PRAGMA wal_autocheckpoint=1000")
self.execute_sql("PRAGMA temp_store=MEMORY")
self.execute_sql("PRAGMA journal_mode=WAL")
self.execute_sql("PRAGMA cache_size=-8000") # 8 MB
self.execute_sql("PRAGMA mmap_size=67108864") # 64 MB
except Exception as exc:
print(f"SQLite pragma setup failed: {exc}")
+27 -24
View File
@@ -135,17 +135,18 @@ class MessageDAO:
if not destination_hashes:
return
now = datetime.now(UTC).isoformat()
for destination_hash in destination_hashes:
self.provider.execute(
"""
INSERT INTO lxmf_conversation_read_state (destination_hash, last_read_at, created_at, updated_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(destination_hash) DO UPDATE SET
last_read_at = EXCLUDED.last_read_at,
updated_at = EXCLUDED.updated_at
""",
(destination_hash, now, now, now),
)
with self.provider:
for destination_hash in destination_hashes:
self.provider.execute(
"""
INSERT INTO lxmf_conversation_read_state (destination_hash, last_read_at, created_at, updated_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(destination_hash) DO UPDATE SET
last_read_at = EXCLUDED.last_read_at,
updated_at = EXCLUDED.updated_at
""",
(destination_hash, now, now, now),
)
def is_conversation_unread(self, destination_hash):
row = self.provider.fetchone(
@@ -310,17 +311,18 @@ class MessageDAO:
def mark_all_notifications_as_viewed(self, destination_hashes=None):
now = datetime.now(UTC).isoformat()
if destination_hashes:
for destination_hash in destination_hashes:
self.provider.execute(
"""
INSERT INTO notification_viewed_state (destination_hash, last_viewed_at, created_at, updated_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(destination_hash) DO UPDATE SET
last_viewed_at = EXCLUDED.last_viewed_at,
updated_at = EXCLUDED.updated_at
""",
(destination_hash, now, now, now),
)
with self.provider:
for destination_hash in destination_hashes:
self.provider.execute(
"""
INSERT INTO notification_viewed_state (destination_hash, last_viewed_at, created_at, updated_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(destination_hash) DO UPDATE SET
last_viewed_at = EXCLUDED.last_viewed_at,
updated_at = EXCLUDED.updated_at
""",
(destination_hash, now, now, now),
)
else:
# mark all conversations as viewed
self.provider.execute(
@@ -397,8 +399,9 @@ class MessageDAO:
)
def move_conversations_to_folder(self, peer_hashes, folder_id):
    """Move several conversations into one folder.

    All per-conversation moves run inside a single provider transaction so a
    large batch commits once instead of once per row. Removes the stray
    duplicate un-wrapped loop left by a bad merge, which moved every
    conversation twice.
    """
    with self.provider:
        for peer_hash in peer_hashes:
            self.move_conversation_to_folder(peer_hash, folder_id)
def get_all_conversation_folders(self):
    """Return every conversation-folder row as a list of mappings."""
    query = "SELECT * FROM lxmf_conversation_folders"
    return self.provider.fetchall(query)
+34 -1
View File
@@ -13,7 +13,7 @@ def _validate_identifier(name: str, label: str = "identifier") -> str:
class DatabaseSchema:
LATEST_VERSION = 38
LATEST_VERSION = 39
def __init__(self, provider: DatabaseProvider):
    """Keep a reference to the provider used for all schema SQL."""
    self.provider = provider
@@ -1013,6 +1013,39 @@ class DatabaseSchema:
"CREATE INDEX IF NOT EXISTS idx_lxmf_messages_reply_to_hash ON lxmf_messages(reply_to_hash)",
)
if current_version < 39:
# Indexes for contacts JOIN columns (used in message_handler.get_conversations
# and announce_manager.get_filtered_announces OR-based JOINs)
self._safe_execute(
"CREATE INDEX IF NOT EXISTS idx_contacts_lxmf_address ON contacts(lxmf_address)",
)
self._safe_execute(
"CREATE INDEX IF NOT EXISTS idx_contacts_lxst_address ON contacts(lxst_address)",
)
# Notifications: filter by is_viewed
self._safe_execute(
"CREATE INDEX IF NOT EXISTS idx_notifications_is_viewed ON notifications(is_viewed)",
)
# Map drawings: lookup by identity_hash
self._safe_execute(
"CREATE INDEX IF NOT EXISTS idx_map_drawings_identity_hash ON map_drawings(identity_hash)",
)
self._safe_execute(
"CREATE INDEX IF NOT EXISTS idx_map_drawings_identity_name ON map_drawings(identity_hash, name)",
)
# Voicemails: filter by is_read
self._safe_execute(
"CREATE INDEX IF NOT EXISTS idx_voicemails_is_read ON voicemails(is_read)",
)
# Archived pages: ORDER BY created_at for cleanup queries
self._safe_execute(
"CREATE INDEX IF NOT EXISTS idx_archived_pages_created_at ON archived_pages(created_at)",
)
# Conversation message state+peer: for failed_count subquery
self._safe_execute(
"CREATE INDEX IF NOT EXISTS idx_lxmf_messages_state_peer ON lxmf_messages(state, peer_hash)",
)
# Update version in config
self._safe_execute(
"""
+98 -66
View File
@@ -42,22 +42,24 @@ st_nasty_text = st.text(
max_size=300,
)
# SQL-injection payloads fed to every DAO query parameter under test.
# The duplicated assignment left by a bad merge (two copies of the same
# strategy) is collapsed into a single definition.
st_sql_payloads = st.sampled_from(
    [
        "'; DROP TABLE config; --",
        "' OR '1'='1",
        '" OR "1"="1',
        "1; SELECT * FROM sqlite_master",
        "Robert'); DROP TABLE lxmf_messages;--",
        "%",
        "%%",
        "_",
        "\x00",
        "' UNION SELECT key, value FROM config --",
        "NULL",
        "1=1",
        "admin'--",
        "' AND 1=CONVERT(int,(SELECT TOP 1 table_name FROM information_schema.tables))--",
    ]
)

# A search term is either arbitrary nasty text or one of the SQL payloads.
st_search_term = st.one_of(st_nasty_text, st_sql_payloads)
@@ -68,6 +70,7 @@ st_hex_hash = st.from_regex(r"[0-9a-f]{16,64}", fullmatch=True)
# Fixture: initialised in-memory database
# ---------------------------------------------------------------------------
@pytest.fixture
def db():
database = Database(":memory:")
@@ -85,8 +88,8 @@ def handler(db):
# ContactsDAO
# ===================================================================
class TestContactsDAOFuzzing:
class TestContactsDAOFuzzing:
@given(
name=st_nasty_text,
identity_hash=st_hex_hash,
@@ -144,8 +147,8 @@ class TestContactsDAOFuzzing:
# ConfigDAO
# ===================================================================
class TestConfigDAOFuzzing:
class TestConfigDAOFuzzing:
@given(key=st_nasty_text.filter(lambda x: len(x) > 0), value=st_nasty_text)
@settings(
deadline=None,
@@ -187,8 +190,8 @@ class TestConfigDAOFuzzing:
# MiscDAO — spam keywords, notifications, keyboard shortcuts
# ===================================================================
class TestMiscDAOFuzzing:
class TestMiscDAOFuzzing:
@given(keyword=st_nasty_text.filter(lambda x: len(x) > 0))
@settings(
deadline=None,
@@ -221,7 +224,9 @@ class TestMiscDAOFuzzing:
suppress_health_check=[HealthCheck.function_scoped_fixture],
max_examples=40,
)
def test_add_notification_never_crashes(self, db, ntype, remote_hash, title, content):
def test_add_notification_never_crashes(
self, db, ntype, remote_hash, title, content
):
db.misc.add_notification(ntype, remote_hash, title, content)
notifications = db.misc.get_notifications()
assert isinstance(notifications, list)
@@ -262,8 +267,8 @@ class TestMiscDAOFuzzing:
# TelephoneDAO
# ===================================================================
class TestTelephoneDAOFuzzing:
class TestTelephoneDAOFuzzing:
@given(
name=st_nasty_text,
identity_hash=st_hex_hash,
@@ -314,8 +319,8 @@ class TestTelephoneDAOFuzzing:
# VoicemailDAO
# ===================================================================
class TestVoicemailDAOFuzzing:
class TestVoicemailDAOFuzzing:
@given(
name=st_nasty_text,
identity_hash=st_hex_hash,
@@ -354,8 +359,8 @@ class TestVoicemailDAOFuzzing:
# DebugLogsDAO
# ===================================================================
class TestDebugLogsDAOFuzzing:
class TestDebugLogsDAOFuzzing:
@given(
level=st.sampled_from(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
module=st_nasty_text,
@@ -400,7 +405,10 @@ class TestDebugLogsDAOFuzzing:
def test_count_matches_result_length(self, db, search, level):
"""get_total_count must be consistent with len(get_logs) when no limit truncation."""
results = db.debug_logs.get_logs(
search=search, level=level, limit=100000, offset=0,
search=search,
level=level,
limit=100000,
offset=0,
)
count = db.debug_logs.get_total_count(search=search, level=level)
assert count == len(results)
@@ -410,8 +418,8 @@ class TestDebugLogsDAOFuzzing:
# RingtoneDAO
# ===================================================================
class TestRingtoneDAOFuzzing:
class TestRingtoneDAOFuzzing:
@given(
filename=st_nasty_text.filter(lambda x: len(x) > 0),
display_name=st_nasty_text.filter(lambda x: len(x) > 0),
@@ -445,13 +453,18 @@ class TestRingtoneDAOFuzzing:
# MapDrawingsDAO
# ===================================================================
class TestMapDrawingsDAOFuzzing:
class TestMapDrawingsDAOFuzzing:
@given(
name=st_nasty_text.filter(lambda x: len(x) > 0),
data=st.one_of(
st_nasty_text,
st.builds(json.dumps, st.dictionaries(st.text(max_size=10), st.text(max_size=50), max_size=10)),
st.builds(
json.dumps,
st.dictionaries(
st.text(max_size=10), st.text(max_size=50), max_size=10
),
),
),
)
@settings(
@@ -470,8 +483,8 @@ class TestMapDrawingsDAOFuzzing:
# MessageDAO — folders
# ===================================================================
class TestMessageDAOFoldersFuzzing:
class TestMessageDAOFoldersFuzzing:
@given(name=st_nasty_text.filter(lambda x: len(x) > 0))
@settings(
deadline=None,
@@ -507,8 +520,8 @@ class TestMessageDAOFoldersFuzzing:
# MessageHandler — search
# ===================================================================
class TestMessageHandlerFuzzing:
class TestMessageHandlerFuzzing:
@given(search=st_search_term)
@settings(
deadline=None,
@@ -557,9 +570,14 @@ class TestMessageHandlerFuzzing:
suppress_health_check=[HealthCheck.function_scoped_fixture],
max_examples=40,
)
def test_get_conversation_messages_never_crashes(self, handler, dest, after_id, before_id):
def test_get_conversation_messages_never_crashes(
self, handler, dest, after_id, before_id
):
results = handler.get_conversation_messages(
"local_hash", dest, after_id=after_id, before_id=before_id,
"local_hash",
dest,
after_id=after_id,
before_id=before_id,
)
assert isinstance(results, list)
@@ -568,8 +586,8 @@ class TestMessageHandlerFuzzing:
# _safe_href — URL sanitization
# ===================================================================
class TestSafeHrefFuzzing:
class TestSafeHrefFuzzing:
@given(url=st.text(max_size=500))
@settings(deadline=None, max_examples=200)
def test_safe_href_never_returns_dangerous_protocol(self, url):
@@ -619,14 +637,20 @@ class TestSafeHrefFuzzing:
def test_safe_href_blocks_all_generated_dangerous_urls(self, url):
assert _safe_href(url) == "#"
@given(scheme=st.text(min_size=1, max_size=20).filter(lambda s: ":" not in s and "/" not in s))
@given(
scheme=st.text(min_size=1, max_size=20).filter(
lambda s: ":" not in s and "/" not in s
)
)
@settings(deadline=None, max_examples=60)
def test_safe_href_blocks_unknown_schemes(self, scheme):
"""Any unknown scheme like 'foo:bar' should be blocked."""
url = f"{scheme}:something"
result = _safe_href(url)
lower = url.strip().lower()
if any(lower.startswith(p) for p in ("https://", "http://", "/", "#", "mailto:")):
if any(
lower.startswith(p) for p in ("https://", "http://", "/", "#", "mailto:")
):
assert result == url
else:
assert result == "#"
@@ -636,31 +660,34 @@ class TestSafeHrefFuzzing:
# parse_bool_query_param
# ===================================================================
class TestParseBoolQueryParam:
class TestParseBoolQueryParam:
@given(value=st.text(max_size=100))
@settings(deadline=None, max_examples=200)
def test_never_crashes_and_returns_bool(self, value):
    # Any text input must yield a bool — never raise.
    result = parse_bool_query_param(value)
    assert isinstance(result, bool)
# Exact truth table for the accepted boolean spellings. The duplicated
# @parametrize decorator left by a bad merge is collapsed to one copy.
@pytest.mark.parametrize(
    "value,expected",
    [
        ("1", True),
        ("true", True),
        ("True", True),
        ("TRUE", True),
        ("yes", True),
        ("YES", True),
        ("on", True),
        ("ON", True),
        ("0", False),
        ("false", False),
        ("no", False),
        ("off", False),
        ("", False),
        ("maybe", False),
        (None, False),
    ],
)
def test_known_values(self, value, expected):
    """Known truthy/falsy spellings map to the documented bool."""
    assert parse_bool_query_param(value) is expected
@@ -669,8 +696,8 @@ class TestParseBoolQueryParam:
# message_fields_have_attachments
# ===================================================================
class TestMessageFieldsHaveAttachments:
class TestMessageFieldsHaveAttachments:
@given(data=st.text(max_size=500))
@settings(deadline=None, max_examples=100)
def test_never_crashes_on_arbitrary_text(self, data):
@@ -689,25 +716,30 @@ class TestMessageFieldsHaveAttachments:
result = message_fields_have_attachments(json.dumps(obj))
assert isinstance(result, bool)
# Canonical attachment-field payloads with exact expected results. The
# duplicated @parametrize decorator left by a bad merge is collapsed.
@pytest.mark.parametrize(
    "fields_json,expected",
    [
        (None, False),
        ("", False),
        ("{}", False),
        ("not json", False),
        ('{"image": "base64data"}', True),
        ('{"audio": "base64data"}', True),
        ('{"file_attachments": [{"name": "f.txt"}]}', True),
        ('{"file_attachments": []}', False),
        ('{"other_field": "value"}', False),
    ],
)
def test_known_cases(self, fields_json, expected):
    """Known payloads map to the documented attachment flag."""
    assert message_fields_have_attachments(fields_json) is expected
@given(
data=st.recursive(
st.one_of(st.text(max_size=20), st.integers(), st.booleans(), st.none()),
lambda children: st.lists(children, max_size=5)
| st.dictionaries(st.text(max_size=10), children, max_size=5),
lambda children: (
st.lists(children, max_size=5)
| st.dictionaries(st.text(max_size=10), children, max_size=5)
),
max_leaves=30,
)
)
+181 -13
View File
@@ -22,7 +22,6 @@ import tempfile
import threading
import time
import unittest
from unittest.mock import MagicMock
from meshchatx.src.backend.announce_manager import AnnounceManager
from meshchatx.src.backend.database import Database
@@ -33,6 +32,7 @@ from meshchatx.src.backend.message_handler import MessageHandler
# Helpers
# ---------------------------------------------------------------------------
def percentile(data, pct):
"""Return the pct-th percentile of sorted data."""
if not data:
@@ -108,8 +108,8 @@ def make_announce(i):
# Test class
# ---------------------------------------------------------------------------
class TestPerformanceHotPaths(unittest.TestCase):
class TestPerformanceHotPaths(unittest.TestCase):
NUM_MESSAGES = 10_000
NUM_PEERS = 200
NUM_ANNOUNCES = 5_000
@@ -134,7 +134,7 @@ class TestPerformanceHotPaths(unittest.TestCase):
@classmethod
def _seed_data(cls):
print(f"\n--- Seeding test data ---")
print("\n--- Seeding test data ---")
# Peers
cls.peer_hashes = [secrets.token_hex(16) for _ in range(cls.NUM_PEERS)]
@@ -286,7 +286,9 @@ class TestPerformanceHotPaths(unittest.TestCase):
dest = secrets.token_hex(16)
_, ms = timed_call(
self.db.announces.upsert_favourite,
dest, f"Bench Fav {i}", "lxmf.delivery",
dest,
f"Bench Fav {i}",
"lxmf.delivery",
)
durations.append(ms)
@@ -316,7 +318,13 @@ class TestPerformanceHotPaths(unittest.TestCase):
def test_conversations_search_latency(self):
"""Search conversations — LIKE across titles, content, peer hashes."""
print("\n[Conversations] Search:")
terms = ["Message title 5", "Content body", "abc", "zzz_nope", self.heavy_peer[:8]]
terms = [
"Message title 5",
"Content body",
"abc",
"zzz_nope",
self.heavy_peer[:8],
]
durations = []
for term in terms:
_, ms = timed_call(
@@ -450,7 +458,9 @@ class TestPerformanceHotPaths(unittest.TestCase):
with lock:
all_durations.extend(thread_durations)
threads = [threading.Thread(target=writer, args=(t,)) for t in range(num_threads)]
threads = [
threading.Thread(target=writer, args=(t,)) for t in range(num_threads)
]
t0 = time.perf_counter()
for t in threads:
t.start()
@@ -460,8 +470,10 @@ class TestPerformanceHotPaths(unittest.TestCase):
total_ops = num_threads * msgs_per_thread
throughput = total_ops / (wall_ms / 1000)
print(f" Wall time: {wall_ms:.0f}ms for {total_ops} inserts ({throughput:.0f} ops/s)")
stats = latency_report("concurrent_write", all_durations)
print(
f" Wall time: {wall_ms:.0f}ms for {total_ops} inserts ({throughput:.0f} ops/s)"
)
latency_report("concurrent_write", all_durations)
self.assertEqual(len(errors), 0, f"Writer errors: {errors[:5]}")
self.assertGreater(throughput, 100, "Concurrent write throughput < 100 ops/s")
@@ -487,7 +499,9 @@ class TestPerformanceHotPaths(unittest.TestCase):
with lock:
all_durations.extend(thread_durations)
threads = [threading.Thread(target=writer, args=(t,)) for t in range(num_threads)]
threads = [
threading.Thread(target=writer, args=(t,)) for t in range(num_threads)
]
t0 = time.perf_counter()
for t in threads:
t.start()
@@ -497,8 +511,10 @@ class TestPerformanceHotPaths(unittest.TestCase):
total_ops = num_threads * announces_per_thread
throughput = total_ops / (wall_ms / 1000)
print(f" Wall time: {wall_ms:.0f}ms for {total_ops} upserts ({throughput:.0f} ops/s)")
stats = latency_report("concurrent_announce_write", all_durations)
print(
f" Wall time: {wall_ms:.0f}ms for {total_ops} upserts ({throughput:.0f} ops/s)"
)
latency_report("concurrent_announce_write", all_durations)
self.assertEqual(len(errors), 0, f"Writer errors: {errors[:5]}")
self.assertGreater(throughput, 100, "Concurrent announce write < 100 ops/s")
@@ -543,8 +559,12 @@ class TestPerformanceHotPaths(unittest.TestCase):
with lock:
read_durations.extend(local_durs)
writers = [threading.Thread(target=writer, args=(t,)) for t in range(num_writers)]
readers = [threading.Thread(target=reader, args=(t,)) for t in range(num_readers)]
writers = [
threading.Thread(target=writer, args=(t,)) for t in range(num_writers)
]
readers = [
threading.Thread(target=reader, args=(t,)) for t in range(num_readers)
]
t0 = time.perf_counter()
for t in writers + readers:
@@ -593,6 +613,154 @@ class TestPerformanceHotPaths(unittest.TestCase):
print(f" Contacts LIKE search ({self.NUM_CONTACTS} rows): {ms_con:.2f}ms")
self.assertLess(ms_con, 50, "Contacts LIKE search > 50ms")
# ===================================================================
# N+1 BATCH OPERATIONS — transaction wrapping regression tests
# ===================================================================
def test_mark_conversations_as_read_batch(self):
    """mark_conversations_as_read should be fast for large batches (transaction-wrapped)."""
    print("\n[Batch] mark_conversations_as_read:")
    # 200 random 16-byte hex hashes; the call upserts read state per hash.
    hashes = [secrets.token_hex(16) for _ in range(200)]
    durations = []
    # Five timed runs to smooth out one-off jitter.
    for _ in range(5):
        _, ms = timed_call(self.db.messages.mark_conversations_as_read, hashes)
        durations.append(ms)
    stats = latency_report("mark_read_200", durations)
    # Regression guard: p95 must stay under 50ms for a 200-row batch.
    self.assertLess(stats["p95"], 50, "mark_conversations_as_read(200) p95 > 50ms")
def test_mark_all_notifications_as_viewed_batch(self):
    """mark_all_notifications_as_viewed should be fast for large batches."""
    print("\n[Batch] mark_all_notifications_as_viewed:")
    # 200 random destination hashes exercise the per-hash upsert path.
    hashes = [secrets.token_hex(16) for _ in range(200)]
    durations = []
    for _ in range(5):
        _, ms = timed_call(
            self.db.messages.mark_all_notifications_as_viewed, hashes
        )
        durations.append(ms)
    stats = latency_report("mark_viewed_200", durations)
    # Regression guard: p95 over 5 runs must stay under 50ms.
    self.assertLess(
        stats["p95"], 50, "mark_all_notifications_as_viewed(200) p95 > 50ms"
    )
def test_move_conversations_to_folder_batch(self):
    """move_conversations_to_folder should be fast for large batches."""
    print("\n[Batch] move_conversations_to_folder:")
    self.db.messages.create_folder("perf_test_folder")
    # NOTE(review): takes folders[0] as the folder just created — holds only
    # if the fixture has no pre-existing folders; confirm against setUpClass.
    folders = self.db.messages.get_all_folders()
    folder_id = folders[0]["id"]
    hashes = [secrets.token_hex(16) for _ in range(200)]
    durations = []
    for _ in range(5):
        _, ms = timed_call(
            self.db.messages.move_conversations_to_folder, hashes, folder_id
        )
        durations.append(ms)
    stats = latency_report("move_folder_200", durations)
    # Regression guard for the transaction-wrapped batch move.
    self.assertLess(
        stats["p95"], 50, "move_conversations_to_folder(200) p95 > 50ms"
    )
# ===================================================================
# INDEX VERIFICATION — confirm new indexes are used
# ===================================================================
def test_indexes_exist(self):
    """Verify critical indexes exist in the schema."""
    print("\n[Indexes] Checking critical indexes exist:")
    # Every index name, for all tables, from the SQLite catalog.
    rows = self.db.provider.fetchall(
        "SELECT name FROM sqlite_master WHERE type='index'"
    )
    index_names = {r["name"] for r in rows}
    # Indexes introduced in schema v39 plus pre-existing hot-path indexes.
    expected = [
        "idx_contacts_lxmf_address",
        "idx_contacts_lxst_address",
        "idx_notifications_is_viewed",
        "idx_map_drawings_identity_hash",
        "idx_map_drawings_identity_name",
        "idx_voicemails_is_read",
        "idx_archived_pages_created_at",
        "idx_lxmf_messages_state_peer",
        "idx_lxmf_messages_peer_hash",
        "idx_lxmf_messages_peer_ts",
        "idx_announces_updated_at",
        "idx_announces_aspect",
    ]
    for idx in expected:
        self.assertIn(idx, index_names, f"Missing index: {idx}")
        print(f" {idx}: OK")
def test_pragmas_applied(self):
    """Verify performance PRAGMAs are active."""
    print("\n[PRAGMAs] Checking applied PRAGMAs:")
    # NOTE(review): relies on Database._get_pragma_value, defined elsewhere
    # in the project — presumably returns the raw PRAGMA query result.
    journal = self.db._get_pragma_value("journal_mode")
    print(f" journal_mode: {journal}")
    self.assertEqual(journal, "wal")
    sync = self.db._get_pragma_value("synchronous")
    print(f" synchronous: {sync}")
    self.assertEqual(sync, 1)  # NORMAL = 1
    temp_store = self.db._get_pragma_value("temp_store")
    print(f" temp_store: {temp_store}")
    self.assertEqual(temp_store, 2)  # MEMORY = 2
    cache_size = self.db._get_pragma_value("cache_size")
    print(f" cache_size: {cache_size}")
    # Negative cache_size is in KiB; -8000 means at least an 8 MB page cache.
    self.assertLessEqual(cache_size, -8000)
# ===================================================================
# QUERY PLAN CHECKS — confirm indexes are actually used
# ===================================================================
def test_query_plan_messages_by_peer(self):
    """The most common message query should use peer_hash index."""
    print("\n[Query Plan] Messages by peer_hash:")
    rows = self.db.provider.fetchall(
        "EXPLAIN QUERY PLAN SELECT * FROM lxmf_messages WHERE peer_hash = ? ORDER BY id DESC LIMIT 50",
        ("test",),
    )
    # Join all plan steps so a single substring check covers any of them.
    plan = " ".join(str(r["detail"]) for r in rows)
    print(f" {plan}")
    self.assertIn("idx_lxmf_messages_peer_hash", plan.lower())
def test_query_plan_announces_by_aspect(self):
    """Announce filtering by aspect should use the aspect index."""
    print("\n[Query Plan] Announces by aspect:")
    rows = self.db.provider.fetchall(
        "EXPLAIN QUERY PLAN SELECT * FROM announces WHERE aspect = ? ORDER BY updated_at DESC LIMIT 50",
        ("lxmf.delivery",),
    )
    # Join all plan steps so a single substring check covers any of them.
    plan = " ".join(str(r["detail"]) for r in rows)
    print(f" {plan}")
    self.assertIn("idx_announces_aspect", plan.lower())
def test_query_plan_failed_messages_state_peer(self):
    """The failed_count subquery should use the state+peer composite index."""
    print("\n[Query Plan] Failed messages (state, peer_hash):")
    rows = self.db.provider.fetchall(
        "EXPLAIN QUERY PLAN SELECT COUNT(*) FROM lxmf_messages WHERE state = 'failed' AND peer_hash = ?",
        ("test",),
    )
    # Join all plan steps so a single substring check covers any of them.
    plan = " ".join(str(r["detail"]) for r in rows)
    print(f" {plan}")
    self.assertIn("idx_lxmf_messages_state_peer", plan.lower())
def test_query_plan_notifications_unread(self):
    """Notification unread filter should use the is_viewed index."""
    print("\n[Query Plan] Notifications unread:")
    rows = self.db.provider.fetchall(
        "EXPLAIN QUERY PLAN SELECT * FROM notifications WHERE is_viewed = 0 ORDER BY timestamp DESC LIMIT 50",
    )
    # Join all plan steps so a single substring check covers any of them.
    plan = " ".join(str(r["detail"]) for r in rows)
    print(f" {plan}")
    self.assertIn("idx_notifications_is_viewed", plan.lower())
# Allow running the performance suite directly with `python`, outside pytest.
if __name__ == "__main__":
    unittest.main()
+9 -14
View File
@@ -9,7 +9,6 @@ Covers:
import os
import re
import sqlite3
import tempfile
import pytest
from hypothesis import HealthCheck, given, settings
@@ -24,6 +23,7 @@ from meshchatx.src.backend.database.schema import DatabaseSchema, _validate_iden
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def db_env(tmp_path):
"""Provide an initialized DatabaseProvider + Schema in a temp directory."""
@@ -52,8 +52,8 @@ def _make_legacy_db(legacy_dir, identity_hash, tables_sql):
# 1. ATTACH DATABASE — single-quote escaping
# ---------------------------------------------------------------------------
class TestAttachDatabasePathEscaping:
class TestAttachDatabasePathEscaping:
def test_path_without_quotes_migrates_normally(self, db_env):
provider, _schema, tmp_path = db_env
legacy_dir = str(tmp_path / "legacy_normal")
@@ -153,8 +153,8 @@ class TestAttachDatabasePathEscaping:
# 2. Legacy migrator — malicious column names filtered out
# ---------------------------------------------------------------------------
class TestLegacyColumnFiltering:
class TestLegacyColumnFiltering:
def test_normal_columns_migrate(self, db_env):
"""Standard column names pass through the identifier filter."""
provider, _schema, tmp_path = db_env
@@ -188,9 +188,7 @@ class TestLegacyColumnFiltering:
conn.execute(
'CREATE TABLE config (key TEXT UNIQUE, value TEXT, "key; DROP TABLE config" TEXT)'
)
conn.execute(
"INSERT INTO config (key, value) VALUES ('safe', 'data')"
)
conn.execute("INSERT INTO config (key, value) VALUES ('safe', 'data')")
conn.commit()
conn.close()
@@ -211,12 +209,8 @@ class TestLegacyColumnFiltering:
os.makedirs(identity_dir, exist_ok=True)
db_path = os.path.join(identity_dir, "database.db")
conn = sqlite3.connect(db_path)
conn.execute(
'CREATE TABLE config (key TEXT UNIQUE, value TEXT, "evil()" TEXT)'
)
conn.execute(
"INSERT INTO config (key, value) VALUES ('p1', 'parens')"
)
conn.execute('CREATE TABLE config (key TEXT UNIQUE, value TEXT, "evil()" TEXT)')
conn.execute("INSERT INTO config (key, value) VALUES ('p1', 'parens')")
conn.commit()
conn.close()
@@ -231,8 +225,8 @@ class TestLegacyColumnFiltering:
# 3. _validate_identifier — unit tests
# ---------------------------------------------------------------------------
class TestValidateIdentifier:
class TestValidateIdentifier:
@pytest.mark.parametrize(
"name",
[
@@ -273,8 +267,8 @@ class TestValidateIdentifier:
# 4. _ensure_column — rejects injection via table/column names
# ---------------------------------------------------------------------------
class TestEnsureColumnInjection:
class TestEnsureColumnInjection:
def test_ensure_column_rejects_malicious_table_name(self, db_env):
_provider, schema, _tmp_path = db_env
with pytest.raises(ValueError, match="Invalid SQL table name"):
@@ -354,6 +348,7 @@ def test_validate_identifier_rejects_pure_metacharacter_strings(name):
# 6. ATTACH path escaping — property-based
# ---------------------------------------------------------------------------
@given(
path_segment=st.text(
alphabet=st.characters(