mirror of
https://git.quad4.io/RNS-Things/MeshChatX.git
synced 2026-05-11 07:26:53 +00:00
refactor(database): remove LegacyMigrator and associated tests to streamline database migration process
This commit is contained in:
@@ -1,157 +0,0 @@
|
||||
# SPDX-License-Identifier: 0BSD
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
# Plain SQL identifiers only. ``\Z`` is used instead of ``$`` because ``$``
# also matches just before a trailing newline, which would let "name\n" pass.
_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*\Z")


class LegacyMigrator:
    """Copy rows from a legacy per-identity SQLite database into the current one.

    The legacy file is searched for at
    ``<config_dir>/identities/<identity_hash_hex>/database.db``.

    Args:
        provider: Database access object. This class uses its ``db_path``
            attribute and its ``execute``, ``fetchone`` and ``fetchall``
            methods; rows are read by column name (``row["name"]``).
        reticulum_config_dir: Preferred config directory to probe first; may
            be empty/None, in which case only the default locations are used.
        identity_hash_hex: Name of the identity sub-directory to look in.
    """

    # Tables that existed in the legacy Peewee version
    _TABLES_TO_MIGRATE = (
        "announces",
        "blocked_destinations",
        "config",
        "custom_destination_display_names",
        "favourite_destinations",
        "lxmf_conversation_read_state",
        "lxmf_messages",
        "lxmf_user_icons",
        "spam_keywords",
    )

    def __init__(self, provider, reticulum_config_dir, identity_hash_hex):
        self.provider = provider
        self.reticulum_config_dir = reticulum_config_dir
        self.identity_hash_hex = identity_hash_hex

    def _current_db_path(self):
        """Return the resolved path of the provider's database, or None if unavailable."""
        try:
            return os.path.realpath(self.provider.db_path)
        except (AttributeError, TypeError, OSError):
            # Without a usable path the "same file" check is simply skipped.
            return None

    def get_legacy_db_path(self):
        """Detect the path to the legacy database based on the Reticulum config directory.

        Returns:
            The first existing legacy database file that is not the current
            database, or None when no candidate is found.
        """
        if not self.identity_hash_hex:
            # No identity hash means there is no per-identity directory to probe.
            return None

        possible_dirs = []
        if self.reticulum_config_dir:
            possible_dirs.append(self.reticulum_config_dir)

        # Add common default locations
        home = os.path.expanduser("~")
        possible_dirs.append(os.path.join(home, ".reticulum-meshchat"))
        possible_dirs.append(os.path.join(home, ".reticulum"))

        current_db_path = self._current_db_path()

        for config_dir in possible_dirs:
            legacy_path = os.path.join(
                config_dir,
                "identities",
                self.identity_hash_hex,
                "database.db",
            )
            # Only a regular file can be attached; a directory of that name is ignored.
            if not os.path.isfile(legacy_path):
                continue
            # Never migrate a database into itself (symlinks are resolved first).
            if current_db_path is not None and os.path.realpath(legacy_path) == current_db_path:
                continue
            return legacy_path

        return None

    def should_migrate(self):
        """Return whether migration should run.

        Only migrates when the current database is empty and a legacy DB exists.
        """
        if not self.get_legacy_db_path():
            return False

        # Check if current DB has any messages
        try:
            res = self.provider.fetchone("SELECT COUNT(*) as count FROM lxmf_messages")
        except Exception:
            # Deliberately broad: the table (or the whole database) may not be
            # initialized yet, which counts as "empty".
            return True

        # Already have data, don't auto-migrate
        return not (res and res["count"] > 0)

    def _migrate_table(self, alias, table):
        """Copy the shared columns of one table from the attached legacy schema into ``main``."""
        # ``alias`` is generated by migrate() and ``table`` comes from the class
        # whitelist, so interpolating them into SQL is safe.
        check_query = f"SELECT name FROM {alias}.sqlite_master WHERE type='table' AND name=?"
        if not self.provider.fetchone(check_query, (table,)):
            return

        legacy_columns = [
            row["name"]
            for row in self.provider.fetchall(f"PRAGMA {alias}.table_info({table})")
        ]
        # Qualify with ``main``: an unqualified name falls through to the
        # attached legacy schema when the current database lacks the table.
        current_columns = {
            row["name"]
            for row in self.provider.fetchall(f"PRAGMA main.table_info({table})")
        }

        # Find common columns, but exclude 'id' to avoid collisions during migration
        # as new databases will have their own autoincrement IDs.
        common_columns = [
            col
            for col in legacy_columns
            if col in current_columns
            and col.lower() != "id"
            and _IDENTIFIER_RE.fullmatch(col)
        ]

        if not common_columns:
            print(
                f" - Skipping table {table}: No common columns found",
            )
            return

        cols_str = ", ".join(common_columns)
        # INSERT OR IGNORE avoids duplicates; the target is pinned to ``main``
        # so rows can never be written back into the legacy file.
        migrate_query = (
            f"INSERT OR IGNORE INTO main.{table} ({cols_str}) "
            f"SELECT {cols_str} FROM {alias}.{table}"
        )
        self.provider.execute(migrate_query)
        print(
            f" - Migrated table: {table} ({len(common_columns)} columns)",
        )

    def migrate(self):
        """Perform the migration from the legacy database.

        Returns:
            True when the legacy database was attached, processed and detached
            (individual table failures are reported and skipped); False when
            no legacy database exists or attach/detach failed.
        """
        legacy_path = self.get_legacy_db_path()
        if not legacy_path:
            return False

        print(f"Detecting legacy database at {legacy_path}...")

        # We use a randomized alias to avoid collisions
        alias = f"legacy_{os.urandom(4).hex()}"
        # ATTACH takes a string literal, so embedded single quotes are doubled.
        safe_path = legacy_path.replace("'", "''")

        try:
            self.provider.execute(f"ATTACH DATABASE '{safe_path}' AS {alias}")
            try:
                print("Auto-migrating data from legacy database...")
                for table in self._TABLES_TO_MIGRATE:
                    try:
                        self._migrate_table(alias, table)
                    except Exception as e:
                        # One broken table must not abort the remaining ones.
                        print(f" - Failed to migrate table {table}: {e}")
            finally:
                # Always release the legacy file once it has been attached.
                self.provider.execute(f"DETACH DATABASE {alias}")

            print("Legacy migration completed successfully.")
            return True
        except Exception as e:
            print(f"Migration from legacy failed: {e}")
            return False
|
||||
@@ -1,168 +0,0 @@
|
||||
# SPDX-License-Identifier: 0BSD
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from meshchatx.src.backend.database import Database
|
||||
from meshchatx.src.backend.database.legacy_migrator import LegacyMigrator
|
||||
from meshchatx.src.backend.database.provider import DatabaseProvider
|
||||
|
||||
|
||||
class TestDatabaseMigration(unittest.TestCase):
    """End-to-end check that LegacyMigrator copies legacy rows into a fresh database."""

    def setUp(self):
        DatabaseProvider._instance = None
        self.test_dir = tempfile.mkdtemp()
        # Registered right away: tearDown is not called when setUp itself
        # fails, but cleanups are, so the directory can never be left behind.
        self.addCleanup(shutil.rmtree, self.test_dir)

        # Legacy migrator expects a specific structure:
        # reticulum_config_dir/identities/identity_hash_hex/database.db
        self.identity_hash = "deadbeef"
        self.legacy_config_dir = os.path.join(self.test_dir, "legacy_config")
        self.legacy_db_subdir = os.path.join(
            self.legacy_config_dir,
            "identities",
            self.identity_hash,
        )
        os.makedirs(self.legacy_db_subdir, exist_ok=True)
        self.legacy_db_path = os.path.join(self.legacy_db_subdir, "database.db")

        # Create legacy database with 1.x/2.x schema
        self.create_legacy_db(self.legacy_db_path)

        # Current database
        self.current_db_path = os.path.join(self.test_dir, "current.db")
        self.db = Database(self.current_db_path)
        # Cleanups run last-in-first-out, so the database is closed before
        # the directory holding it is removed.
        self.addCleanup(self.db.close_all)
        self.db.initialize()

    def create_legacy_db(self, path):
        """Create a legacy-schema database at ``path`` with one row per table."""
        conn = sqlite3.connect(path)
        try:
            cursor = conn.cursor()

            # Based on liamcottle/reticulum-meshchat/database.py
            cursor.execute("""
                CREATE TABLE config (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    key TEXT UNIQUE,
                    value TEXT,
                    created_at DATETIME,
                    updated_at DATETIME
                )
            """)

            cursor.execute("""
                CREATE TABLE announces (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    destination_hash TEXT UNIQUE,
                    aspect TEXT,
                    identity_hash TEXT,
                    identity_public_key TEXT,
                    app_data TEXT,
                    rssi INTEGER,
                    snr REAL,
                    quality REAL,
                    created_at DATETIME,
                    updated_at DATETIME
                )
            """)

            cursor.execute("""
                CREATE TABLE lxmf_messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    hash TEXT UNIQUE,
                    source_hash TEXT,
                    destination_hash TEXT,
                    state TEXT,
                    progress REAL,
                    is_incoming INTEGER,
                    method TEXT,
                    delivery_attempts INTEGER,
                    next_delivery_attempt_at REAL,
                    title TEXT,
                    content TEXT,
                    fields TEXT,
                    timestamp REAL,
                    rssi INTEGER,
                    snr REAL,
                    quality REAL,
                    created_at DATETIME,
                    updated_at DATETIME
                )
            """)

            # Insert some legacy data
            cursor.execute(
                "INSERT INTO config (key, value) VALUES (?, ?)",
                ("legacy_key", "legacy_value"),
            )
            cursor.execute(
                "INSERT INTO announces (destination_hash, aspect, identity_hash) VALUES (?, ?, ?)",
                ("dest1", "lxmf.delivery", "id1"),
            )
            cursor.execute(
                "INSERT INTO lxmf_messages (hash, source_hash, destination_hash, title, content, fields, is_incoming, state, progress, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
                (
                    "msg1",
                    "src1",
                    "dest1",
                    "Old Title",
                    "Old Content",
                    "{}",
                    1,
                    "delivered",
                    1.0,
                    123456789.0,
                ),
            )

            conn.commit()
        finally:
            # Close even when a statement fails so the file handle is released.
            conn.close()

    def test_migration_evolution(self):
        migrator = LegacyMigrator(
            self.db.provider,
            self.legacy_config_dir,
            self.identity_hash,
        )

        # Check if should migrate
        self.assertTrue(
            migrator.should_migrate(),
            "Should detect legacy database for migration",
        )

        # Perform migration
        success = migrator.migrate()
        self.assertTrue(success, "Migration should complete successfully")

        # Verify data in current database
        config_val = self.db.provider.fetchone(
            "SELECT value FROM config WHERE key = ?",
            ("legacy_key",),
        )
        self.assertIsNotNone(config_val, "legacy_key should have been migrated")
        self.assertEqual(config_val["value"], "legacy_value")

        ann_count = self.db.provider.fetchone(
            "SELECT COUNT(*) as count FROM announces",
        )["count"]
        self.assertEqual(ann_count, 1)

        msg = self.db.provider.fetchone(
            "SELECT * FROM lxmf_messages WHERE hash = ?",
            ("msg1",),
        )
        self.assertIsNotNone(msg)
        self.assertEqual(msg["title"], "Old Title")
        self.assertEqual(msg["content"], "Old Content")
        self.assertEqual(msg["source_hash"], "src1")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,384 +0,0 @@
|
||||
# SPDX-License-Identifier: 0BSD
|
||||
|
||||
"""Legacy migration and schema: safe ATTACH paths, identifiers, and raw SQL helpers.
|
||||
|
||||
Covers ATTACH DATABASE path escaping in LegacyMigrator, column identifier filtering,
|
||||
and DatabaseSchema _validate_identifier / _ensure_column behaviour.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
|
||||
import pytest
|
||||
from hypothesis import HealthCheck, given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from meshchatx.src.backend.database.legacy_migrator import LegacyMigrator
|
||||
from meshchatx.src.backend.database.provider import DatabaseProvider
|
||||
from meshchatx.src.backend.database.schema import DatabaseSchema, _validate_identifier
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
def db_env(tmp_path):
    """Provide an initialized DatabaseProvider + Schema in a temp directory.

    Yields:
        A ``(provider, schema, tmp_path)`` tuple; the provider is closed when
        the test finishes.
    """
    db_path = str(tmp_path / "current.db")
    provider = DatabaseProvider(db_path)
    try:
        schema = DatabaseSchema(provider)
        schema.initialize()
        yield provider, schema, tmp_path
    finally:
        # Also reached when schema setup raises, so the provider never stays open.
        provider.close()
|
||||
|
||||
|
||||
def _make_legacy_db(legacy_dir, identity_hash, tables_sql):
    """Create a legacy database with the given CREATE TABLE + INSERT statements.

    Args:
        legacy_dir: Root of the legacy config directory.
        identity_hash: Name of the identity sub-directory to create.
        tables_sql: SQL statements executed in order on the new database.

    Returns:
        Path of the created ``database.db`` file.
    """
    identity_dir = os.path.join(legacy_dir, "identities", identity_hash)
    os.makedirs(identity_dir, exist_ok=True)
    db_path = os.path.join(identity_dir, "database.db")
    conn = sqlite3.connect(db_path)
    try:
        for sql in tables_sql:
            conn.execute(sql)
        conn.commit()
    finally:
        # Release the file handle even when one of the statements fails.
        conn.close()
    return db_path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. ATTACH DATABASE — single-quote escaping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAttachDatabasePathEscaping:
    """Legacy paths containing single quotes must be attached as literals, never as SQL."""

    @staticmethod
    def _migrate_config_row(provider, legacy_dir, identity_hash, key, value):
        """Build a one-row legacy ``config`` table and return the result of migrating it."""
        _make_legacy_db(
            legacy_dir,
            identity_hash,
            [
                "CREATE TABLE config (key TEXT UNIQUE, value TEXT)",
                f"INSERT INTO config (key, value) VALUES ('{key}', '{value}')",
            ],
        )
        return LegacyMigrator(provider, legacy_dir, identity_hash).migrate()

    @staticmethod
    def _assert_config_value(provider, key, expected):
        """Assert that ``key`` was copied into the current ``config`` table."""
        row = provider.fetchone("SELECT value FROM config WHERE key = ?", (key,))
        assert row is not None
        assert row["value"] == expected

    def test_path_without_quotes_migrates_normally(self, db_env):
        provider, _schema, tmp_path = db_env
        legacy_dir = str(tmp_path / "legacy_normal")

        assert self._migrate_config_row(provider, legacy_dir, "aabbccdd", "k1", "v1") is True
        self._assert_config_value(provider, "k1", "v1")

    def test_path_with_single_quote_does_not_crash(self, db_env):
        """A path containing a single quote must not cause SQL injection or crash."""
        provider, _schema, tmp_path = db_env
        quoted_dir = tmp_path / "it's_a_test"
        quoted_dir.mkdir(parents=True, exist_ok=True)

        result = self._migrate_config_row(
            provider, str(quoted_dir), "aabbccdd", "q1", "quoted_val"
        )
        assert result is True
        self._assert_config_value(provider, "q1", "quoted_val")

    def test_path_with_multiple_quotes(self, db_env):
        """Multiple single quotes in the path are all escaped."""
        provider, _schema, tmp_path = db_env
        weird_dir = tmp_path / "a'b'c"
        weird_dir.mkdir(parents=True, exist_ok=True)

        result = self._migrate_config_row(
            provider, str(weird_dir), "11223344", "mq", "multi_quote"
        )
        assert result is True
        self._assert_config_value(provider, "mq", "multi_quote")

    def test_path_with_sql_injection_attempt(self, db_env):
        """A path crafted to look like SQL injection is safely escaped."""
        provider, _schema, tmp_path = db_env
        evil_dir = tmp_path / "'; DROP TABLE config; --"
        evil_dir.mkdir(parents=True, exist_ok=True)

        # The whole path must be treated as one string literal, so the
        # migration succeeds and the row arrives. Without these checks the
        # test would also pass when ATTACH simply failed.
        result = self._migrate_config_row(provider, str(evil_dir), "deadbeef", "evil", "nope")
        assert result is True
        self._assert_config_value(provider, "evil", "nope")

        tables = provider.fetchall(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='config'",
        )
        assert len(tables) > 0, "config table must still exist after injection attempt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Legacy migrator — malicious column names filtered out
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLegacyColumnFiltering:
    """Legacy columns whose names are not plain identifiers never reach the INSERT."""

    def test_normal_columns_migrate(self, db_env):
        """Standard column names pass through the identifier filter."""
        provider, _, base = db_env
        source_dir = str(base / "legacy_cols")
        statements = [
            "CREATE TABLE config (key TEXT UNIQUE, value TEXT)",
            "INSERT INTO config (key, value) VALUES ('c1', 'ok')",
        ]
        _make_legacy_db(source_dir, "aabb0011", statements)

        outcome = LegacyMigrator(provider, source_dir, "aabb0011").migrate()
        assert outcome is True

        migrated = provider.fetchone("SELECT value FROM config WHERE key = 'c1'")
        assert migrated is not None

    def test_malicious_column_name_is_skipped(self, db_env):
        """A column with SQL metacharacters in its name must be silently skipped."""
        provider, _, base = db_env
        source_dir = str(base / "legacy_evil_col")
        statements = [
            'CREATE TABLE config (key TEXT UNIQUE, value TEXT, "key; DROP TABLE config" TEXT)',
            "INSERT INTO config (key, value) VALUES ('safe', 'data')",
        ]
        _make_legacy_db(source_dir, "cc00dd00", statements)

        outcome = LegacyMigrator(provider, source_dir, "cc00dd00").migrate()
        assert outcome is True

        migrated = provider.fetchone("SELECT value FROM config WHERE key = 'safe'")
        assert migrated is not None
        assert migrated["value"] == "data"

    def test_column_with_parentheses_is_skipped(self, db_env):
        """Columns with () in the name are rejected by the identifier regex."""
        provider, _, base = db_env
        source_dir = str(base / "legacy_parens_col")
        statements = [
            'CREATE TABLE config (key TEXT UNIQUE, value TEXT, "evil()" TEXT)',
            "INSERT INTO config (key, value) VALUES ('p1', 'parens')",
        ]
        _make_legacy_db(source_dir, "ee00ff00", statements)

        outcome = LegacyMigrator(provider, source_dir, "ee00ff00").migrate()
        assert outcome is True

        migrated = provider.fetchone("SELECT value FROM config WHERE key = 'p1'")
        assert migrated is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. _validate_identifier — unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateIdentifier:
    """Unit tests for ``_validate_identifier`` on fixed good and bad names."""

    # Names that must be returned unchanged.
    _ACCEPTED = (
        "config",
        "lxmf_messages",
        "A",
        "_private",
        "Column123",
        "a_b_c_d",
    )

    # Names that must be rejected with ValueError.
    _REJECTED = (
        "",
        "123abc",
        "table name",
        "col;drop",
        "a'b",
        'a"b',
        "col()",
        "x--y",
        "a,b",
        "hello\nworld",
        "tab\there",
        "col/**/name",
    )

    @pytest.mark.parametrize("name", _ACCEPTED)
    def test_valid_identifiers_pass(self, name):
        returned = _validate_identifier(name)
        assert returned == name

    @pytest.mark.parametrize("name", _REJECTED)
    def test_invalid_identifiers_raise(self, name):
        with pytest.raises(ValueError, match="Invalid SQL"):
            _validate_identifier(name)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. _ensure_column — rejects injection via table/column names
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnsureColumnInjection:
    """``DatabaseSchema._ensure_column`` must validate both table and column names."""

    def test_ensure_column_rejects_malicious_table_name(self, db_env):
        schema = db_env[1]
        bad_table = "config; DROP TABLE config"
        with pytest.raises(ValueError, match="Invalid SQL table name"):
            schema._ensure_column(bad_table, "new_col", "TEXT")

    def test_ensure_column_rejects_malicious_column_name(self, db_env):
        schema = db_env[1]
        bad_column = "col; DROP TABLE config"
        with pytest.raises(ValueError, match="Invalid SQL column name"):
            schema._ensure_column("config", bad_column, "TEXT")

    def test_ensure_column_works_for_valid_names(self, db_env):
        schema = db_env[1]
        assert schema._ensure_column("config", "test_new_col", "TEXT") is True

    def test_ensure_column_idempotent(self, db_env):
        schema = db_env[1]
        # The first call adds the column; the second must still report True.
        schema._ensure_column("config", "idempotent_col", "TEXT")
        assert schema._ensure_column("config", "idempotent_col", "TEXT") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Property-based tests — identifier regex
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
|
||||
|
||||
@given(name=st.text(min_size=1, max_size=80))
@settings(deadline=None)
def test_validate_identifier_never_allows_sql_metacharacters(name):
    """No string accepted by _validate_identifier contains SQL metacharacters."""
    try:
        _validate_identifier(name)
    except ValueError:
        # Rejected names are out of scope for this property.
        return

    # Checked in a fixed order, one character at a time.
    forbidden = (";", "'", '"', "(", ")", " ", "-", "/", "\\", "\n", "\r", "\t", ",")
    for symbol in forbidden:
        assert symbol not in name
    assert _IDENTIFIER_RE.match(name)
|
||||
|
||||
|
||||
@given(name=st.from_regex(r"[A-Za-z_][A-Za-z0-9_]{0,30}", fullmatch=True))
@settings(deadline=None)
def test_validate_identifier_accepts_all_valid_identifiers(name):
    """Every string matching the identifier pattern is accepted."""
    accepted = _validate_identifier(name)
    # The validator hands back the very name it was given.
    assert accepted == name
|
||||
|
||||
|
||||
@given(
    name=st.text(
        alphabet=st.sampled_from(list(";'\"()- \t\n\r,/*")),
        min_size=1,
        max_size=30,
    ),
)
@settings(deadline=None)
def test_validate_identifier_rejects_pure_metacharacter_strings(name):
    """Strings composed entirely of SQL metacharacters are always rejected."""
    # Every generated character comes from the metacharacter alphabet above,
    # so no input here can be a legal identifier.
    with pytest.raises(ValueError):
        _validate_identifier(name)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. ATTACH path escaping — property-based
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@given(
    path_segment=st.text(
        alphabet=st.characters(
            whitelist_categories=("L", "N", "P", "S", "Z"),
        ),
        min_size=1,
        max_size=60,
    ),
)
@settings(deadline=None, suppress_health_check=[HealthCheck.too_slow])
def test_attach_path_escaping_never_breaks_sql(path_segment):
    """Quote-doubled ATTACH paths stay inside the SQL string literal."""
    prefix = "ATTACH DATABASE '"
    safe = path_segment.replace("'", "''")
    sql = f"{prefix}{safe}' AS test_alias"

    # Only the position of the prefix matters. Counting occurrences would
    # fail spuriously when the generated segment itself contains the prefix.
    assert sql.startswith(prefix)
    after_open = sql[len(prefix):]

    # Walk the literal: a doubled quote is an escaped quote, a single one ends it.
    closing_index = None
    i = 0
    while i < len(after_open):
        if after_open[i] == "'":
            if after_open[i + 1 : i + 2] == "'":
                i += 2
                continue
            closing_index = i
            break
        i += 1

    # The literal must terminate; otherwise the remainder check below would
    # be skipped and the property would pass without verifying anything.
    assert closing_index is not None, "string literal was never terminated"

    remainder = after_open[closing_index + 1 :]
    assert remainder.strip() == "AS test_alias", (
        f"Unexpected SQL after literal end: {remainder!r}"
    )
    # Un-escaping the literal must give back exactly the original segment.
    assert after_open[:closing_index].replace("''", "'") == path_segment
|
||||
Reference in New Issue
Block a user