mirror of
https://git.quad4.io/RNS-Things/MeshChatX.git
synced 2026-05-15 10:55:08 +00:00
Improve database performance and SQL handling
- Introduced SQLite pragma tuning in the Database initialization for improved performance. - Wrapped multiple database operations in transactions to optimize batch processing in MessageDAO. - Updated DatabaseSchema to version 39, adding new indexes for better query performance. - Improved test coverage for batch operations and SQL injection scenarios in the DAO layer.
This commit is contained in:
@@ -43,6 +43,7 @@ class Database:
|
||||
self.debug_logs = DebugLogsDAO(self.provider)
|
||||
|
||||
def initialize(self):
|
||||
self._tune_sqlite_pragmas()
|
||||
self.schema.initialize()
|
||||
|
||||
def migrate_from_legacy(self, reticulum_config_dir, identity_hash_hex):
|
||||
@@ -60,9 +61,12 @@ class Database:
|
||||
|
||||
def _tune_sqlite_pragmas(self):
|
||||
try:
|
||||
self.execute_sql("PRAGMA journal_mode=WAL")
|
||||
self.execute_sql("PRAGMA synchronous=NORMAL")
|
||||
self.execute_sql("PRAGMA wal_autocheckpoint=1000")
|
||||
self.execute_sql("PRAGMA temp_store=MEMORY")
|
||||
self.execute_sql("PRAGMA journal_mode=WAL")
|
||||
self.execute_sql("PRAGMA cache_size=-8000") # 8 MB
|
||||
self.execute_sql("PRAGMA mmap_size=67108864") # 64 MB
|
||||
except Exception as exc:
|
||||
print(f"SQLite pragma setup failed: {exc}")
|
||||
|
||||
|
||||
@@ -135,17 +135,18 @@ class MessageDAO:
|
||||
if not destination_hashes:
|
||||
return
|
||||
now = datetime.now(UTC).isoformat()
|
||||
for destination_hash in destination_hashes:
|
||||
self.provider.execute(
|
||||
"""
|
||||
INSERT INTO lxmf_conversation_read_state (destination_hash, last_read_at, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(destination_hash) DO UPDATE SET
|
||||
last_read_at = EXCLUDED.last_read_at,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
""",
|
||||
(destination_hash, now, now, now),
|
||||
)
|
||||
with self.provider:
|
||||
for destination_hash in destination_hashes:
|
||||
self.provider.execute(
|
||||
"""
|
||||
INSERT INTO lxmf_conversation_read_state (destination_hash, last_read_at, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(destination_hash) DO UPDATE SET
|
||||
last_read_at = EXCLUDED.last_read_at,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
""",
|
||||
(destination_hash, now, now, now),
|
||||
)
|
||||
|
||||
def is_conversation_unread(self, destination_hash):
|
||||
row = self.provider.fetchone(
|
||||
@@ -310,17 +311,18 @@ class MessageDAO:
|
||||
def mark_all_notifications_as_viewed(self, destination_hashes=None):
|
||||
now = datetime.now(UTC).isoformat()
|
||||
if destination_hashes:
|
||||
for destination_hash in destination_hashes:
|
||||
self.provider.execute(
|
||||
"""
|
||||
INSERT INTO notification_viewed_state (destination_hash, last_viewed_at, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(destination_hash) DO UPDATE SET
|
||||
last_viewed_at = EXCLUDED.last_viewed_at,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
""",
|
||||
(destination_hash, now, now, now),
|
||||
)
|
||||
with self.provider:
|
||||
for destination_hash in destination_hashes:
|
||||
self.provider.execute(
|
||||
"""
|
||||
INSERT INTO notification_viewed_state (destination_hash, last_viewed_at, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(destination_hash) DO UPDATE SET
|
||||
last_viewed_at = EXCLUDED.last_viewed_at,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
""",
|
||||
(destination_hash, now, now, now),
|
||||
)
|
||||
else:
|
||||
# mark all conversations as viewed
|
||||
self.provider.execute(
|
||||
@@ -397,8 +399,9 @@ class MessageDAO:
|
||||
)
|
||||
|
||||
def move_conversations_to_folder(self, peer_hashes, folder_id):
|
||||
for peer_hash in peer_hashes:
|
||||
self.move_conversation_to_folder(peer_hash, folder_id)
|
||||
with self.provider:
|
||||
for peer_hash in peer_hashes:
|
||||
self.move_conversation_to_folder(peer_hash, folder_id)
|
||||
|
||||
def get_all_conversation_folders(self):
|
||||
return self.provider.fetchall("SELECT * FROM lxmf_conversation_folders")
|
||||
|
||||
@@ -13,7 +13,7 @@ def _validate_identifier(name: str, label: str = "identifier") -> str:
|
||||
|
||||
|
||||
class DatabaseSchema:
|
||||
LATEST_VERSION = 38
|
||||
LATEST_VERSION = 39
|
||||
|
||||
def __init__(self, provider: DatabaseProvider):
|
||||
self.provider = provider
|
||||
@@ -1013,6 +1013,39 @@ class DatabaseSchema:
|
||||
"CREATE INDEX IF NOT EXISTS idx_lxmf_messages_reply_to_hash ON lxmf_messages(reply_to_hash)",
|
||||
)
|
||||
|
||||
if current_version < 39:
|
||||
# Indexes for contacts JOIN columns (used in message_handler.get_conversations
|
||||
# and announce_manager.get_filtered_announces OR-based JOINs)
|
||||
self._safe_execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_contacts_lxmf_address ON contacts(lxmf_address)",
|
||||
)
|
||||
self._safe_execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_contacts_lxst_address ON contacts(lxst_address)",
|
||||
)
|
||||
# Notifications: filter by is_viewed
|
||||
self._safe_execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_notifications_is_viewed ON notifications(is_viewed)",
|
||||
)
|
||||
# Map drawings: lookup by identity_hash
|
||||
self._safe_execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_map_drawings_identity_hash ON map_drawings(identity_hash)",
|
||||
)
|
||||
self._safe_execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_map_drawings_identity_name ON map_drawings(identity_hash, name)",
|
||||
)
|
||||
# Voicemails: filter by is_read
|
||||
self._safe_execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_voicemails_is_read ON voicemails(is_read)",
|
||||
)
|
||||
# Archived pages: ORDER BY created_at for cleanup queries
|
||||
self._safe_execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_archived_pages_created_at ON archived_pages(created_at)",
|
||||
)
|
||||
# Conversation message state+peer: for failed_count subquery
|
||||
self._safe_execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_lxmf_messages_state_peer ON lxmf_messages(state, peer_hash)",
|
||||
)
|
||||
|
||||
# Update version in config
|
||||
self._safe_execute(
|
||||
"""
|
||||
|
||||
@@ -42,22 +42,24 @@ st_nasty_text = st.text(
|
||||
max_size=300,
|
||||
)
|
||||
|
||||
st_sql_payloads = st.sampled_from([
|
||||
"'; DROP TABLE config; --",
|
||||
"' OR '1'='1",
|
||||
"\" OR \"1\"=\"1",
|
||||
"1; SELECT * FROM sqlite_master",
|
||||
"Robert'); DROP TABLE lxmf_messages;--",
|
||||
"%",
|
||||
"%%",
|
||||
"_",
|
||||
"\x00",
|
||||
"' UNION SELECT key, value FROM config --",
|
||||
"NULL",
|
||||
"1=1",
|
||||
"admin'--",
|
||||
"' AND 1=CONVERT(int,(SELECT TOP 1 table_name FROM information_schema.tables))--",
|
||||
])
|
||||
st_sql_payloads = st.sampled_from(
|
||||
[
|
||||
"'; DROP TABLE config; --",
|
||||
"' OR '1'='1",
|
||||
'" OR "1"="1',
|
||||
"1; SELECT * FROM sqlite_master",
|
||||
"Robert'); DROP TABLE lxmf_messages;--",
|
||||
"%",
|
||||
"%%",
|
||||
"_",
|
||||
"\x00",
|
||||
"' UNION SELECT key, value FROM config --",
|
||||
"NULL",
|
||||
"1=1",
|
||||
"admin'--",
|
||||
"' AND 1=CONVERT(int,(SELECT TOP 1 table_name FROM information_schema.tables))--",
|
||||
]
|
||||
)
|
||||
|
||||
st_search_term = st.one_of(st_nasty_text, st_sql_payloads)
|
||||
|
||||
@@ -68,6 +70,7 @@ st_hex_hash = st.from_regex(r"[0-9a-f]{16,64}", fullmatch=True)
|
||||
# Fixture: initialised in-memory database
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db():
|
||||
database = Database(":memory:")
|
||||
@@ -85,8 +88,8 @@ def handler(db):
|
||||
# ContactsDAO
|
||||
# ===================================================================
|
||||
|
||||
class TestContactsDAOFuzzing:
|
||||
|
||||
class TestContactsDAOFuzzing:
|
||||
@given(
|
||||
name=st_nasty_text,
|
||||
identity_hash=st_hex_hash,
|
||||
@@ -144,8 +147,8 @@ class TestContactsDAOFuzzing:
|
||||
# ConfigDAO
|
||||
# ===================================================================
|
||||
|
||||
class TestConfigDAOFuzzing:
|
||||
|
||||
class TestConfigDAOFuzzing:
|
||||
@given(key=st_nasty_text.filter(lambda x: len(x) > 0), value=st_nasty_text)
|
||||
@settings(
|
||||
deadline=None,
|
||||
@@ -187,8 +190,8 @@ class TestConfigDAOFuzzing:
|
||||
# MiscDAO — spam keywords, notifications, keyboard shortcuts
|
||||
# ===================================================================
|
||||
|
||||
class TestMiscDAOFuzzing:
|
||||
|
||||
class TestMiscDAOFuzzing:
|
||||
@given(keyword=st_nasty_text.filter(lambda x: len(x) > 0))
|
||||
@settings(
|
||||
deadline=None,
|
||||
@@ -221,7 +224,9 @@ class TestMiscDAOFuzzing:
|
||||
suppress_health_check=[HealthCheck.function_scoped_fixture],
|
||||
max_examples=40,
|
||||
)
|
||||
def test_add_notification_never_crashes(self, db, ntype, remote_hash, title, content):
|
||||
def test_add_notification_never_crashes(
|
||||
self, db, ntype, remote_hash, title, content
|
||||
):
|
||||
db.misc.add_notification(ntype, remote_hash, title, content)
|
||||
notifications = db.misc.get_notifications()
|
||||
assert isinstance(notifications, list)
|
||||
@@ -262,8 +267,8 @@ class TestMiscDAOFuzzing:
|
||||
# TelephoneDAO
|
||||
# ===================================================================
|
||||
|
||||
class TestTelephoneDAOFuzzing:
|
||||
|
||||
class TestTelephoneDAOFuzzing:
|
||||
@given(
|
||||
name=st_nasty_text,
|
||||
identity_hash=st_hex_hash,
|
||||
@@ -314,8 +319,8 @@ class TestTelephoneDAOFuzzing:
|
||||
# VoicemailDAO
|
||||
# ===================================================================
|
||||
|
||||
class TestVoicemailDAOFuzzing:
|
||||
|
||||
class TestVoicemailDAOFuzzing:
|
||||
@given(
|
||||
name=st_nasty_text,
|
||||
identity_hash=st_hex_hash,
|
||||
@@ -354,8 +359,8 @@ class TestVoicemailDAOFuzzing:
|
||||
# DebugLogsDAO
|
||||
# ===================================================================
|
||||
|
||||
class TestDebugLogsDAOFuzzing:
|
||||
|
||||
class TestDebugLogsDAOFuzzing:
|
||||
@given(
|
||||
level=st.sampled_from(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
|
||||
module=st_nasty_text,
|
||||
@@ -400,7 +405,10 @@ class TestDebugLogsDAOFuzzing:
|
||||
def test_count_matches_result_length(self, db, search, level):
|
||||
"""get_total_count must be consistent with len(get_logs) when no limit truncation."""
|
||||
results = db.debug_logs.get_logs(
|
||||
search=search, level=level, limit=100000, offset=0,
|
||||
search=search,
|
||||
level=level,
|
||||
limit=100000,
|
||||
offset=0,
|
||||
)
|
||||
count = db.debug_logs.get_total_count(search=search, level=level)
|
||||
assert count == len(results)
|
||||
@@ -410,8 +418,8 @@ class TestDebugLogsDAOFuzzing:
|
||||
# RingtoneDAO
|
||||
# ===================================================================
|
||||
|
||||
class TestRingtoneDAOFuzzing:
|
||||
|
||||
class TestRingtoneDAOFuzzing:
|
||||
@given(
|
||||
filename=st_nasty_text.filter(lambda x: len(x) > 0),
|
||||
display_name=st_nasty_text.filter(lambda x: len(x) > 0),
|
||||
@@ -445,13 +453,18 @@ class TestRingtoneDAOFuzzing:
|
||||
# MapDrawingsDAO
|
||||
# ===================================================================
|
||||
|
||||
class TestMapDrawingsDAOFuzzing:
|
||||
|
||||
class TestMapDrawingsDAOFuzzing:
|
||||
@given(
|
||||
name=st_nasty_text.filter(lambda x: len(x) > 0),
|
||||
data=st.one_of(
|
||||
st_nasty_text,
|
||||
st.builds(json.dumps, st.dictionaries(st.text(max_size=10), st.text(max_size=50), max_size=10)),
|
||||
st.builds(
|
||||
json.dumps,
|
||||
st.dictionaries(
|
||||
st.text(max_size=10), st.text(max_size=50), max_size=10
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
@settings(
|
||||
@@ -470,8 +483,8 @@ class TestMapDrawingsDAOFuzzing:
|
||||
# MessageDAO — folders
|
||||
# ===================================================================
|
||||
|
||||
class TestMessageDAOFoldersFuzzing:
|
||||
|
||||
class TestMessageDAOFoldersFuzzing:
|
||||
@given(name=st_nasty_text.filter(lambda x: len(x) > 0))
|
||||
@settings(
|
||||
deadline=None,
|
||||
@@ -507,8 +520,8 @@ class TestMessageDAOFoldersFuzzing:
|
||||
# MessageHandler — search
|
||||
# ===================================================================
|
||||
|
||||
class TestMessageHandlerFuzzing:
|
||||
|
||||
class TestMessageHandlerFuzzing:
|
||||
@given(search=st_search_term)
|
||||
@settings(
|
||||
deadline=None,
|
||||
@@ -557,9 +570,14 @@ class TestMessageHandlerFuzzing:
|
||||
suppress_health_check=[HealthCheck.function_scoped_fixture],
|
||||
max_examples=40,
|
||||
)
|
||||
def test_get_conversation_messages_never_crashes(self, handler, dest, after_id, before_id):
|
||||
def test_get_conversation_messages_never_crashes(
|
||||
self, handler, dest, after_id, before_id
|
||||
):
|
||||
results = handler.get_conversation_messages(
|
||||
"local_hash", dest, after_id=after_id, before_id=before_id,
|
||||
"local_hash",
|
||||
dest,
|
||||
after_id=after_id,
|
||||
before_id=before_id,
|
||||
)
|
||||
assert isinstance(results, list)
|
||||
|
||||
@@ -568,8 +586,8 @@ class TestMessageHandlerFuzzing:
|
||||
# _safe_href — URL sanitization
|
||||
# ===================================================================
|
||||
|
||||
class TestSafeHrefFuzzing:
|
||||
|
||||
class TestSafeHrefFuzzing:
|
||||
@given(url=st.text(max_size=500))
|
||||
@settings(deadline=None, max_examples=200)
|
||||
def test_safe_href_never_returns_dangerous_protocol(self, url):
|
||||
@@ -619,14 +637,20 @@ class TestSafeHrefFuzzing:
|
||||
def test_safe_href_blocks_all_generated_dangerous_urls(self, url):
|
||||
assert _safe_href(url) == "#"
|
||||
|
||||
@given(scheme=st.text(min_size=1, max_size=20).filter(lambda s: ":" not in s and "/" not in s))
|
||||
@given(
|
||||
scheme=st.text(min_size=1, max_size=20).filter(
|
||||
lambda s: ":" not in s and "/" not in s
|
||||
)
|
||||
)
|
||||
@settings(deadline=None, max_examples=60)
|
||||
def test_safe_href_blocks_unknown_schemes(self, scheme):
|
||||
"""Any unknown scheme like 'foo:bar' should be blocked."""
|
||||
url = f"{scheme}:something"
|
||||
result = _safe_href(url)
|
||||
lower = url.strip().lower()
|
||||
if any(lower.startswith(p) for p in ("https://", "http://", "/", "#", "mailto:")):
|
||||
if any(
|
||||
lower.startswith(p) for p in ("https://", "http://", "/", "#", "mailto:")
|
||||
):
|
||||
assert result == url
|
||||
else:
|
||||
assert result == "#"
|
||||
@@ -636,31 +660,34 @@ class TestSafeHrefFuzzing:
|
||||
# parse_bool_query_param
|
||||
# ===================================================================
|
||||
|
||||
class TestParseBoolQueryParam:
|
||||
|
||||
class TestParseBoolQueryParam:
|
||||
@given(value=st.text(max_size=100))
|
||||
@settings(deadline=None, max_examples=200)
|
||||
def test_never_crashes_and_returns_bool(self, value):
|
||||
result = parse_bool_query_param(value)
|
||||
assert isinstance(result, bool)
|
||||
|
||||
@pytest.mark.parametrize("value,expected", [
|
||||
("1", True),
|
||||
("true", True),
|
||||
("True", True),
|
||||
("TRUE", True),
|
||||
("yes", True),
|
||||
("YES", True),
|
||||
("on", True),
|
||||
("ON", True),
|
||||
("0", False),
|
||||
("false", False),
|
||||
("no", False),
|
||||
("off", False),
|
||||
("", False),
|
||||
("maybe", False),
|
||||
(None, False),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"value,expected",
|
||||
[
|
||||
("1", True),
|
||||
("true", True),
|
||||
("True", True),
|
||||
("TRUE", True),
|
||||
("yes", True),
|
||||
("YES", True),
|
||||
("on", True),
|
||||
("ON", True),
|
||||
("0", False),
|
||||
("false", False),
|
||||
("no", False),
|
||||
("off", False),
|
||||
("", False),
|
||||
("maybe", False),
|
||||
(None, False),
|
||||
],
|
||||
)
|
||||
def test_known_values(self, value, expected):
|
||||
assert parse_bool_query_param(value) is expected
|
||||
|
||||
@@ -669,8 +696,8 @@ class TestParseBoolQueryParam:
|
||||
# message_fields_have_attachments
|
||||
# ===================================================================
|
||||
|
||||
class TestMessageFieldsHaveAttachments:
|
||||
|
||||
class TestMessageFieldsHaveAttachments:
|
||||
@given(data=st.text(max_size=500))
|
||||
@settings(deadline=None, max_examples=100)
|
||||
def test_never_crashes_on_arbitrary_text(self, data):
|
||||
@@ -689,25 +716,30 @@ class TestMessageFieldsHaveAttachments:
|
||||
result = message_fields_have_attachments(json.dumps(obj))
|
||||
assert isinstance(result, bool)
|
||||
|
||||
@pytest.mark.parametrize("fields_json,expected", [
|
||||
(None, False),
|
||||
("", False),
|
||||
("{}", False),
|
||||
("not json", False),
|
||||
('{"image": "base64data"}', True),
|
||||
('{"audio": "base64data"}', True),
|
||||
('{"file_attachments": [{"name": "f.txt"}]}', True),
|
||||
('{"file_attachments": []}', False),
|
||||
('{"other_field": "value"}', False),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"fields_json,expected",
|
||||
[
|
||||
(None, False),
|
||||
("", False),
|
||||
("{}", False),
|
||||
("not json", False),
|
||||
('{"image": "base64data"}', True),
|
||||
('{"audio": "base64data"}', True),
|
||||
('{"file_attachments": [{"name": "f.txt"}]}', True),
|
||||
('{"file_attachments": []}', False),
|
||||
('{"other_field": "value"}', False),
|
||||
],
|
||||
)
|
||||
def test_known_cases(self, fields_json, expected):
|
||||
assert message_fields_have_attachments(fields_json) is expected
|
||||
|
||||
@given(
|
||||
data=st.recursive(
|
||||
st.one_of(st.text(max_size=20), st.integers(), st.booleans(), st.none()),
|
||||
lambda children: st.lists(children, max_size=5)
|
||||
| st.dictionaries(st.text(max_size=10), children, max_size=5),
|
||||
lambda children: (
|
||||
st.lists(children, max_size=5)
|
||||
| st.dictionaries(st.text(max_size=10), children, max_size=5)
|
||||
),
|
||||
max_leaves=30,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -22,7 +22,6 @@ import tempfile
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from meshchatx.src.backend.announce_manager import AnnounceManager
|
||||
from meshchatx.src.backend.database import Database
|
||||
@@ -33,6 +32,7 @@ from meshchatx.src.backend.message_handler import MessageHandler
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def percentile(data, pct):
|
||||
"""Return the pct-th percentile of sorted data."""
|
||||
if not data:
|
||||
@@ -108,8 +108,8 @@ def make_announce(i):
|
||||
# Test class
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPerformanceHotPaths(unittest.TestCase):
|
||||
|
||||
class TestPerformanceHotPaths(unittest.TestCase):
|
||||
NUM_MESSAGES = 10_000
|
||||
NUM_PEERS = 200
|
||||
NUM_ANNOUNCES = 5_000
|
||||
@@ -134,7 +134,7 @@ class TestPerformanceHotPaths(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def _seed_data(cls):
|
||||
print(f"\n--- Seeding test data ---")
|
||||
print("\n--- Seeding test data ---")
|
||||
|
||||
# Peers
|
||||
cls.peer_hashes = [secrets.token_hex(16) for _ in range(cls.NUM_PEERS)]
|
||||
@@ -286,7 +286,9 @@ class TestPerformanceHotPaths(unittest.TestCase):
|
||||
dest = secrets.token_hex(16)
|
||||
_, ms = timed_call(
|
||||
self.db.announces.upsert_favourite,
|
||||
dest, f"Bench Fav {i}", "lxmf.delivery",
|
||||
dest,
|
||||
f"Bench Fav {i}",
|
||||
"lxmf.delivery",
|
||||
)
|
||||
durations.append(ms)
|
||||
|
||||
@@ -316,7 +318,13 @@ class TestPerformanceHotPaths(unittest.TestCase):
|
||||
def test_conversations_search_latency(self):
|
||||
"""Search conversations — LIKE across titles, content, peer hashes."""
|
||||
print("\n[Conversations] Search:")
|
||||
terms = ["Message title 5", "Content body", "abc", "zzz_nope", self.heavy_peer[:8]]
|
||||
terms = [
|
||||
"Message title 5",
|
||||
"Content body",
|
||||
"abc",
|
||||
"zzz_nope",
|
||||
self.heavy_peer[:8],
|
||||
]
|
||||
durations = []
|
||||
for term in terms:
|
||||
_, ms = timed_call(
|
||||
@@ -450,7 +458,9 @@ class TestPerformanceHotPaths(unittest.TestCase):
|
||||
with lock:
|
||||
all_durations.extend(thread_durations)
|
||||
|
||||
threads = [threading.Thread(target=writer, args=(t,)) for t in range(num_threads)]
|
||||
threads = [
|
||||
threading.Thread(target=writer, args=(t,)) for t in range(num_threads)
|
||||
]
|
||||
t0 = time.perf_counter()
|
||||
for t in threads:
|
||||
t.start()
|
||||
@@ -460,8 +470,10 @@ class TestPerformanceHotPaths(unittest.TestCase):
|
||||
|
||||
total_ops = num_threads * msgs_per_thread
|
||||
throughput = total_ops / (wall_ms / 1000)
|
||||
print(f" Wall time: {wall_ms:.0f}ms for {total_ops} inserts ({throughput:.0f} ops/s)")
|
||||
stats = latency_report("concurrent_write", all_durations)
|
||||
print(
|
||||
f" Wall time: {wall_ms:.0f}ms for {total_ops} inserts ({throughput:.0f} ops/s)"
|
||||
)
|
||||
latency_report("concurrent_write", all_durations)
|
||||
|
||||
self.assertEqual(len(errors), 0, f"Writer errors: {errors[:5]}")
|
||||
self.assertGreater(throughput, 100, "Concurrent write throughput < 100 ops/s")
|
||||
@@ -487,7 +499,9 @@ class TestPerformanceHotPaths(unittest.TestCase):
|
||||
with lock:
|
||||
all_durations.extend(thread_durations)
|
||||
|
||||
threads = [threading.Thread(target=writer, args=(t,)) for t in range(num_threads)]
|
||||
threads = [
|
||||
threading.Thread(target=writer, args=(t,)) for t in range(num_threads)
|
||||
]
|
||||
t0 = time.perf_counter()
|
||||
for t in threads:
|
||||
t.start()
|
||||
@@ -497,8 +511,10 @@ class TestPerformanceHotPaths(unittest.TestCase):
|
||||
|
||||
total_ops = num_threads * announces_per_thread
|
||||
throughput = total_ops / (wall_ms / 1000)
|
||||
print(f" Wall time: {wall_ms:.0f}ms for {total_ops} upserts ({throughput:.0f} ops/s)")
|
||||
stats = latency_report("concurrent_announce_write", all_durations)
|
||||
print(
|
||||
f" Wall time: {wall_ms:.0f}ms for {total_ops} upserts ({throughput:.0f} ops/s)"
|
||||
)
|
||||
latency_report("concurrent_announce_write", all_durations)
|
||||
|
||||
self.assertEqual(len(errors), 0, f"Writer errors: {errors[:5]}")
|
||||
self.assertGreater(throughput, 100, "Concurrent announce write < 100 ops/s")
|
||||
@@ -543,8 +559,12 @@ class TestPerformanceHotPaths(unittest.TestCase):
|
||||
with lock:
|
||||
read_durations.extend(local_durs)
|
||||
|
||||
writers = [threading.Thread(target=writer, args=(t,)) for t in range(num_writers)]
|
||||
readers = [threading.Thread(target=reader, args=(t,)) for t in range(num_readers)]
|
||||
writers = [
|
||||
threading.Thread(target=writer, args=(t,)) for t in range(num_writers)
|
||||
]
|
||||
readers = [
|
||||
threading.Thread(target=reader, args=(t,)) for t in range(num_readers)
|
||||
]
|
||||
|
||||
t0 = time.perf_counter()
|
||||
for t in writers + readers:
|
||||
@@ -593,6 +613,154 @@ class TestPerformanceHotPaths(unittest.TestCase):
|
||||
print(f" Contacts LIKE search ({self.NUM_CONTACTS} rows): {ms_con:.2f}ms")
|
||||
self.assertLess(ms_con, 50, "Contacts LIKE search > 50ms")
|
||||
|
||||
# ===================================================================
|
||||
# N+1 BATCH OPERATIONS — transaction wrapping regression tests
|
||||
# ===================================================================
|
||||
|
||||
def test_mark_conversations_as_read_batch(self):
|
||||
"""mark_conversations_as_read should be fast for large batches (transaction-wrapped)."""
|
||||
print("\n[Batch] mark_conversations_as_read:")
|
||||
hashes = [secrets.token_hex(16) for _ in range(200)]
|
||||
durations = []
|
||||
for _ in range(5):
|
||||
_, ms = timed_call(self.db.messages.mark_conversations_as_read, hashes)
|
||||
durations.append(ms)
|
||||
|
||||
stats = latency_report("mark_read_200", durations)
|
||||
self.assertLess(stats["p95"], 50, "mark_conversations_as_read(200) p95 > 50ms")
|
||||
|
||||
def test_mark_all_notifications_as_viewed_batch(self):
|
||||
"""mark_all_notifications_as_viewed should be fast for large batches."""
|
||||
print("\n[Batch] mark_all_notifications_as_viewed:")
|
||||
hashes = [secrets.token_hex(16) for _ in range(200)]
|
||||
durations = []
|
||||
for _ in range(5):
|
||||
_, ms = timed_call(
|
||||
self.db.messages.mark_all_notifications_as_viewed, hashes
|
||||
)
|
||||
durations.append(ms)
|
||||
|
||||
stats = latency_report("mark_viewed_200", durations)
|
||||
self.assertLess(
|
||||
stats["p95"], 50, "mark_all_notifications_as_viewed(200) p95 > 50ms"
|
||||
)
|
||||
|
||||
def test_move_conversations_to_folder_batch(self):
|
||||
"""move_conversations_to_folder should be fast for large batches."""
|
||||
print("\n[Batch] move_conversations_to_folder:")
|
||||
self.db.messages.create_folder("perf_test_folder")
|
||||
folders = self.db.messages.get_all_folders()
|
||||
folder_id = folders[0]["id"]
|
||||
|
||||
hashes = [secrets.token_hex(16) for _ in range(200)]
|
||||
durations = []
|
||||
for _ in range(5):
|
||||
_, ms = timed_call(
|
||||
self.db.messages.move_conversations_to_folder, hashes, folder_id
|
||||
)
|
||||
durations.append(ms)
|
||||
|
||||
stats = latency_report("move_folder_200", durations)
|
||||
self.assertLess(
|
||||
stats["p95"], 50, "move_conversations_to_folder(200) p95 > 50ms"
|
||||
)
|
||||
|
||||
# ===================================================================
|
||||
# INDEX VERIFICATION — confirm new indexes are used
|
||||
# ===================================================================
|
||||
|
||||
def test_indexes_exist(self):
|
||||
"""Verify critical indexes exist in the schema."""
|
||||
print("\n[Indexes] Checking critical indexes exist:")
|
||||
rows = self.db.provider.fetchall(
|
||||
"SELECT name FROM sqlite_master WHERE type='index'"
|
||||
)
|
||||
index_names = {r["name"] for r in rows}
|
||||
|
||||
expected = [
|
||||
"idx_contacts_lxmf_address",
|
||||
"idx_contacts_lxst_address",
|
||||
"idx_notifications_is_viewed",
|
||||
"idx_map_drawings_identity_hash",
|
||||
"idx_map_drawings_identity_name",
|
||||
"idx_voicemails_is_read",
|
||||
"idx_archived_pages_created_at",
|
||||
"idx_lxmf_messages_state_peer",
|
||||
"idx_lxmf_messages_peer_hash",
|
||||
"idx_lxmf_messages_peer_ts",
|
||||
"idx_announces_updated_at",
|
||||
"idx_announces_aspect",
|
||||
]
|
||||
for idx in expected:
|
||||
self.assertIn(idx, index_names, f"Missing index: {idx}")
|
||||
print(f" {idx}: OK")
|
||||
|
||||
def test_pragmas_applied(self):
|
||||
"""Verify performance PRAGMAs are active."""
|
||||
print("\n[PRAGMAs] Checking applied PRAGMAs:")
|
||||
journal = self.db._get_pragma_value("journal_mode")
|
||||
print(f" journal_mode: {journal}")
|
||||
self.assertEqual(journal, "wal")
|
||||
|
||||
sync = self.db._get_pragma_value("synchronous")
|
||||
print(f" synchronous: {sync}")
|
||||
self.assertEqual(sync, 1) # NORMAL = 1
|
||||
|
||||
temp_store = self.db._get_pragma_value("temp_store")
|
||||
print(f" temp_store: {temp_store}")
|
||||
self.assertEqual(temp_store, 2) # MEMORY = 2
|
||||
|
||||
cache_size = self.db._get_pragma_value("cache_size")
|
||||
print(f" cache_size: {cache_size}")
|
||||
self.assertLessEqual(cache_size, -8000)
|
||||
|
||||
# ===================================================================
|
||||
# QUERY PLAN CHECKS — confirm indexes are actually used
|
||||
# ===================================================================
|
||||
|
||||
def test_query_plan_messages_by_peer(self):
|
||||
"""The most common message query should use peer_hash index."""
|
||||
print("\n[Query Plan] Messages by peer_hash:")
|
||||
rows = self.db.provider.fetchall(
|
||||
"EXPLAIN QUERY PLAN SELECT * FROM lxmf_messages WHERE peer_hash = ? ORDER BY id DESC LIMIT 50",
|
||||
("test",),
|
||||
)
|
||||
plan = " ".join(str(r["detail"]) for r in rows)
|
||||
print(f" {plan}")
|
||||
self.assertIn("idx_lxmf_messages_peer_hash", plan.lower())
|
||||
|
||||
def test_query_plan_announces_by_aspect(self):
|
||||
"""Announce filtering by aspect should use the aspect index."""
|
||||
print("\n[Query Plan] Announces by aspect:")
|
||||
rows = self.db.provider.fetchall(
|
||||
"EXPLAIN QUERY PLAN SELECT * FROM announces WHERE aspect = ? ORDER BY updated_at DESC LIMIT 50",
|
||||
("lxmf.delivery",),
|
||||
)
|
||||
plan = " ".join(str(r["detail"]) for r in rows)
|
||||
print(f" {plan}")
|
||||
self.assertIn("idx_announces_aspect", plan.lower())
|
||||
|
||||
def test_query_plan_failed_messages_state_peer(self):
|
||||
"""The failed_count subquery should use the state+peer composite index."""
|
||||
print("\n[Query Plan] Failed messages (state, peer_hash):")
|
||||
rows = self.db.provider.fetchall(
|
||||
"EXPLAIN QUERY PLAN SELECT COUNT(*) FROM lxmf_messages WHERE state = 'failed' AND peer_hash = ?",
|
||||
("test",),
|
||||
)
|
||||
plan = " ".join(str(r["detail"]) for r in rows)
|
||||
print(f" {plan}")
|
||||
self.assertIn("idx_lxmf_messages_state_peer", plan.lower())
|
||||
|
||||
def test_query_plan_notifications_unread(self):
|
||||
"""Notification unread filter should use the is_viewed index."""
|
||||
print("\n[Query Plan] Notifications unread:")
|
||||
rows = self.db.provider.fetchall(
|
||||
"EXPLAIN QUERY PLAN SELECT * FROM notifications WHERE is_viewed = 0 ORDER BY timestamp DESC LIMIT 50",
|
||||
)
|
||||
plan = " ".join(str(r["detail"]) for r in rows)
|
||||
print(f" {plan}")
|
||||
self.assertIn("idx_notifications_is_viewed", plan.lower())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -9,7 +9,6 @@ Covers:
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
from hypothesis import HealthCheck, given, settings
|
||||
@@ -24,6 +23,7 @@ from meshchatx.src.backend.database.schema import DatabaseSchema, _validate_iden
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_env(tmp_path):
|
||||
"""Provide an initialized DatabaseProvider + Schema in a temp directory."""
|
||||
@@ -52,8 +52,8 @@ def _make_legacy_db(legacy_dir, identity_hash, tables_sql):
|
||||
# 1. ATTACH DATABASE — single-quote escaping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAttachDatabasePathEscaping:
|
||||
|
||||
class TestAttachDatabasePathEscaping:
|
||||
def test_path_without_quotes_migrates_normally(self, db_env):
|
||||
provider, _schema, tmp_path = db_env
|
||||
legacy_dir = str(tmp_path / "legacy_normal")
|
||||
@@ -153,8 +153,8 @@ class TestAttachDatabasePathEscaping:
|
||||
# 2. Legacy migrator — malicious column names filtered out
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLegacyColumnFiltering:
|
||||
|
||||
class TestLegacyColumnFiltering:
|
||||
def test_normal_columns_migrate(self, db_env):
|
||||
"""Standard column names pass through the identifier filter."""
|
||||
provider, _schema, tmp_path = db_env
|
||||
@@ -188,9 +188,7 @@ class TestLegacyColumnFiltering:
|
||||
conn.execute(
|
||||
'CREATE TABLE config (key TEXT UNIQUE, value TEXT, "key; DROP TABLE config" TEXT)'
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO config (key, value) VALUES ('safe', 'data')"
|
||||
)
|
||||
conn.execute("INSERT INTO config (key, value) VALUES ('safe', 'data')")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
@@ -211,12 +209,8 @@ class TestLegacyColumnFiltering:
|
||||
os.makedirs(identity_dir, exist_ok=True)
|
||||
db_path = os.path.join(identity_dir, "database.db")
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.execute(
|
||||
'CREATE TABLE config (key TEXT UNIQUE, value TEXT, "evil()" TEXT)'
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO config (key, value) VALUES ('p1', 'parens')"
|
||||
)
|
||||
conn.execute('CREATE TABLE config (key TEXT UNIQUE, value TEXT, "evil()" TEXT)')
|
||||
conn.execute("INSERT INTO config (key, value) VALUES ('p1', 'parens')")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
@@ -231,8 +225,8 @@ class TestLegacyColumnFiltering:
|
||||
# 3. _validate_identifier — unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidateIdentifier:
|
||||
|
||||
class TestValidateIdentifier:
|
||||
@pytest.mark.parametrize(
|
||||
"name",
|
||||
[
|
||||
@@ -273,8 +267,8 @@ class TestValidateIdentifier:
|
||||
# 4. _ensure_column — rejects injection via table/column names
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEnsureColumnInjection:
|
||||
|
||||
class TestEnsureColumnInjection:
|
||||
def test_ensure_column_rejects_malicious_table_name(self, db_env):
|
||||
_provider, schema, _tmp_path = db_env
|
||||
with pytest.raises(ValueError, match="Invalid SQL table name"):
|
||||
@@ -354,6 +348,7 @@ def test_validate_identifier_rejects_pure_metacharacter_strings(name):
|
||||
# 6. ATTACH path escaping — property-based
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@given(
|
||||
path_segment=st.text(
|
||||
alphabet=st.characters(
|
||||
|
||||
Reference in New Issue
Block a user