Files
meshcore-bot/tests/test_db_manager.py
agessaman 217d2a4089 Refactor database connection handling across multiple modules
- Replaced direct SQLite connection calls with a context manager in various modules to ensure proper resource management and prevent file descriptor leaks.
- Introduced a new `connection` method in `DBManager` to standardize connection handling.
- Updated all relevant database interactions in modules such as `feed_manager`, `scheduler`, `commands`, and others to utilize the new connection method.
- Improved code readability and maintainability by consolidating connection logic.
2026-03-01 14:12:22 -08:00

184 lines
6.4 KiB
Python

"""Tests for modules.db_manager."""
import sqlite3
import json
from contextlib import closing
import pytest
from unittest.mock import Mock
from modules.db_manager import DBManager
@pytest.fixture
def db(mock_logger, tmp_path):
    """Provide a DBManager backed by a database file under ``tmp_path``.

    The manager receives a stand-in bot whose ``logger`` attribute is the
    ``mock_logger`` fixture.  _init_database() auto-creates core tables.
    """
    database_file = tmp_path / "test.db"
    stub_bot = Mock()
    stub_bot.logger = mock_logger
    return DBManager(stub_bot, str(database_file))
class TestGeocoding:
    """Tests for geocoding cache."""

    def test_cache_and_retrieve_geocoding(self, db):
        """A cached location is returned with the coordinates that were stored."""
        db.cache_geocoding("Seattle, WA", 47.6062, -122.3321)
        lat, lon = db.get_cached_geocoding("Seattle, WA")
        assert abs(lat - 47.6062) < 0.001
        assert abs(lon - (-122.3321)) < 0.001

    def test_get_cached_geocoding_miss(self, db):
        """A location that was never cached yields ``None`` for both values."""
        lat, lon = db.get_cached_geocoding("Nonexistent City")
        assert lat is None
        assert lon is None

    def test_cache_geocoding_overwrites_existing(self, db):
        """Caching the same location twice keeps the most recent coordinates."""
        db.cache_geocoding("Test", 10.0, 20.0)
        db.cache_geocoding("Test", 30.0, 40.0)
        lat, lon = db.get_cached_geocoding("Test")
        assert abs(lat - 30.0) < 0.001
        assert abs(lon - 40.0) < 0.001

    def test_cache_geocoding_invalid_hours_logged(self, db):
        """Invalid cache_hours is caught and logged, not raised."""
        # Drop any error calls recorded before this point so the assertion
        # below can only be satisfied by the call under test.
        db.bot.logger.error.reset_mock()
        db.cache_geocoding("Test", 10.0, 20.0, cache_hours=0)
        db.bot.logger.error.assert_called()
        # Verify it did not store anything: both halves of the pair are empty.
        lat, lon = db.get_cached_geocoding("Test")
        assert lat is None
        assert lon is None
class TestGenericCache:
    """Tests for generic cache."""

    def test_cache_and_retrieve_value(self, db):
        """A stored value is returned for the same key and cache type."""
        db.cache_value("weather_key", "sunny", "weather")
        assert db.get_cached_value("weather_key", "weather") == "sunny"

    def test_get_cached_value_miss(self, db):
        """Looking up a key that was never stored returns ``None``."""
        missing = db.get_cached_value("nonexistent", "any")
        assert missing is None

    def test_different_keys_stored_independently(self, db):
        """Two keys of the same cache type keep their own values."""
        entries = {"key_a": "value_a", "key_b": "value_b"}
        for cache_key, cache_value in entries.items():
            db.cache_value(cache_key, cache_value, "weather")
        for cache_key, cache_value in entries.items():
            assert db.get_cached_value(cache_key, "weather") == cache_value

    def test_cache_json_round_trip(self, db):
        """A nested dict survives cache_json followed by get_cached_json."""
        payload = {"temp": 72, "conditions": "clear", "nested": {"wind": 5}}
        db.cache_json("forecast", payload, "weather")
        assert db.get_cached_json("forecast", "weather") == payload

    def test_get_cached_json_invalid_json(self, db):
        """Manually insert invalid JSON; get_cached_json returns None."""
        insert_sql = (
            "INSERT INTO generic_cache (cache_key, cache_value, cache_type, expires_at) "
            "VALUES (?, ?, ?, datetime('now', '+24 hours'))"
        )
        bad_row = ("bad_json", "not{valid}json", "test")
        # Write straight to the database file, bypassing DBManager's helpers.
        with closing(sqlite3.connect(str(db.db_path))) as raw_conn:
            raw_conn.execute(insert_sql, bad_row)
            raw_conn.commit()
        assert db.get_cached_json("bad_json", "test") is None
class TestCacheCleanup:
    """Tests for cache expiry cleanup."""

    def test_cleanup_expired_deletes_old(self, db):
        """An entry whose expiry lies in the past is not returned after cleanup."""
        db.cache_value("old_key", "old_val", "test")
        # Manually set expires_at to the past, directly in the database file.
        backdate_sql = (
            "UPDATE generic_cache SET expires_at = datetime('now', '-1 hours') "
            "WHERE cache_key = 'old_key'"
        )
        with closing(sqlite3.connect(str(db.db_path))) as raw_conn:
            raw_conn.execute(backdate_sql)
            raw_conn.commit()
        db.cleanup_expired_cache()
        # NOTE(review): if get_cached_value itself ignores expired rows, this
        # assertion would pass even without the cleanup call -- confirm.
        assert db.get_cached_value("old_key", "test") is None

    def test_cleanup_expired_preserves_valid(self, db):
        """An entry cached with a long lifetime is still readable after cleanup."""
        db.cache_value("fresh_key", "fresh_val", "test", cache_hours=720)
        db.cleanup_expired_cache()
        surviving = db.get_cached_value("fresh_key", "test")
        assert surviving == "fresh_val"
class TestTableManagement:
    """Tests for table creation whitelist."""

    def test_create_table_allowed(self, db):
        """Creating the ``greeted_users`` table makes it appear in sqlite_master."""
        schema = "id INTEGER PRIMARY KEY, name TEXT NOT NULL"
        db.create_table("greeted_users", schema)
        lookup_sql = (
            "SELECT name FROM sqlite_master WHERE type='table' AND name='greeted_users'"
        )
        # Inspect the database file directly rather than through DBManager.
        with closing(sqlite3.connect(str(db.db_path))) as raw_conn:
            found = raw_conn.execute(lookup_sql).fetchone()
        assert found is not None

    def test_create_table_disallowed_raises(self, db):
        """A table name outside the allowed set is rejected with ValueError."""
        expected_error = pytest.raises(ValueError, match="not in allowed tables")
        with expected_error:
            db.create_table("not_allowed", "id INTEGER PRIMARY KEY")

    def test_create_table_sql_injection_name_raises(self, db):
        """A table name carrying SQL syntax is rejected with ValueError."""
        hostile_name = "DROP TABLE users; --"
        with pytest.raises(ValueError):
            db.create_table(hostile_name, "id INTEGER PRIMARY KEY")
class TestExecuteQuery:
    """Tests for raw query execution."""

    def test_execute_query_returns_dicts(self, db):
        """execute_query returns rows whose columns can be read by name."""
        db.set_metadata("test_key", "test_value")
        select_sql = "SELECT * FROM bot_metadata WHERE key = ?"
        rows = db.execute_query(select_sql, ("test_key",))
        assert len(rows) == 1
        only_row = rows[0]
        assert only_row["key"] == "test_key"
        assert only_row["value"] == "test_value"

    def test_execute_update_returns_rowcount(self, db):
        """execute_update reports one affected row when deleting a single entry."""
        db.set_metadata("del_key", "del_value")
        delete_sql = "DELETE FROM bot_metadata WHERE key = ?"
        affected = db.execute_update(delete_sql, ("del_key",))
        assert affected == 1
class TestMetadata:
    """Tests for bot metadata storage."""

    def test_set_and_get_metadata(self, db):
        """A metadata value is returned unchanged for the key it was stored under."""
        db.set_metadata("version", "1.2.3")
        stored = db.get_metadata("version")
        assert stored == "1.2.3"

    def test_get_metadata_miss(self, db):
        """Reading a key that was never set returns ``None``."""
        missing = db.get_metadata("nonexistent")
        assert missing is None

    def test_bot_start_time_round_trip(self, db):
        """A fractional start timestamp is read back exactly as written."""
        started_at = 1234567890.5
        db.set_bot_start_time(started_at)
        assert db.get_bot_start_time() == started_at
class TestCacheHoursValidation:
    """Tests for cache_hours boundary validation."""

    def test_boundary_values(self, db):
        """Values 1 and 87600 are stored; 0 and 87601 are logged and not stored."""
        # Valid boundaries
        db.cache_value("k1", "v1", "t", cache_hours=1)
        assert db.get_cached_value("k1", "t") == "v1"
        db.cache_value("k2", "v2", "t", cache_hours=87600)
        assert db.get_cached_value("k2", "t") == "v2"

        # Invalid boundaries — caught and logged, not stored
        invalid_cases = (("k3", "v3", 0), ("k4", "v4", 87601))
        for cache_key, cache_value, hours in invalid_cases:
            # Reset before every case so assert_called() can only be satisfied
            # by the error logged for this particular cache_hours value, not
            # by anything recorded earlier on the shared logger mock.
            db.bot.logger.error.reset_mock()
            db.cache_value(cache_key, cache_value, "t", cache_hours=hours)
            db.bot.logger.error.assert_called()
            assert db.get_cached_value(cache_key, "t") is None