mirror of
https://git.quad4.io/RNS-Things/MeshChatX.git
synced 2026-05-11 18:14:51 +00:00
feat(community): implement URL validation and fetch handling for community directory to prevent SSRF vulnerabilities; add tests for validation and fetch behavior
This commit is contained in:
@@ -9,6 +9,7 @@ import re
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
DEFAULT_SUBMITTED_URL = (
|
||||
"https://directory.rns.recipes/api/directory/submitted?search=&type=&status=online"
|
||||
@@ -16,6 +17,44 @@ DEFAULT_SUBMITTED_URL = (
|
||||
|
||||
DESCRIPTION = "directory.rns.recipes (user-submitted, online)"
|
||||
|
||||
_ALLOWED_DIRECTORY_HOSTS = frozenset({"directory.rns.recipes"})
|
||||
|
||||
|
||||
def validate_directory_fetch_url(url: str) -> str:
|
||||
"""Reject SSRF: only https to directory.rns.recipes, no credentials."""
|
||||
if not url or not isinstance(url, str):
|
||||
msg = "URL must be a non-empty string"
|
||||
raise ValueError(msg)
|
||||
parsed = urlparse(url.strip())
|
||||
if parsed.scheme != "https":
|
||||
msg = "Community directory URL must use https"
|
||||
raise ValueError(msg)
|
||||
netloc = parsed.netloc or ""
|
||||
if "@" in netloc:
|
||||
msg = "Community directory URL must not contain credentials"
|
||||
raise ValueError(msg)
|
||||
host = (parsed.hostname or "").lower()
|
||||
if host not in _ALLOWED_DIRECTORY_HOSTS:
|
||||
msg = "Community directory URL host is not allowed"
|
||||
raise ValueError(msg)
|
||||
return url.strip()
|
||||
|
||||
|
||||
class _DirectoryFetchNoRedirectHandler(urllib.request.HTTPRedirectHandler):
|
||||
"""urllib follows redirects by default; blocked to prevent SSRF via Location."""
|
||||
|
||||
def redirect_request(self, req, fp, code, msg, headers, newurl):
|
||||
raise urllib.error.HTTPError(
|
||||
req.full_url,
|
||||
code,
|
||||
"Redirects are not followed for community directory fetch",
|
||||
headers,
|
||||
fp,
|
||||
)
|
||||
|
||||
|
||||
_DIRECTORY_FETCH_OPENER = urllib.request.build_opener(_DirectoryFetchNoRedirectHandler())
|
||||
|
||||
|
||||
def fetch_directory_payload(url: str, *, timeout: float = 60.0) -> object:
|
||||
req = urllib.request.Request(
|
||||
@@ -26,7 +65,7 @@ def fetch_directory_payload(url: str, *, timeout: float = 60.0) -> object:
|
||||
},
|
||||
method="GET",
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
with _DIRECTORY_FETCH_OPENER.open(req, timeout=timeout) as resp:
|
||||
return json.loads(resp.read().decode("utf-8"))
|
||||
|
||||
|
||||
@@ -35,7 +74,10 @@ def build_interfaces_from_directory_url(
|
||||
*,
|
||||
timeout: float = 60.0,
|
||||
) -> tuple[list[dict[str, Any]], str]:
|
||||
resolved = url or DEFAULT_SUBMITTED_URL
|
||||
if url is not None and str(url).strip():
|
||||
resolved = validate_directory_fetch_url(url)
|
||||
else:
|
||||
resolved = DEFAULT_SUBMITTED_URL
|
||||
payload = fetch_directory_payload(resolved, timeout=timeout)
|
||||
rows = rows_from_payload(payload)
|
||||
return transform_directory_rows(rows), resolved
|
||||
|
||||
@@ -10,9 +10,14 @@ _UNSAFE_PROTOCOLS = ("javascript:", "data:", "vbscript:", "file:")
|
||||
def _safe_href(url):
|
||||
if not url or not isinstance(url, str):
|
||||
return "#"
|
||||
u = url.strip().lower()
|
||||
trimmed = url.strip()
|
||||
u = trimmed.lower()
|
||||
if any(u.startswith(p) for p in _UNSAFE_PROTOCOLS):
|
||||
return "#"
|
||||
if u.startswith("//"):
|
||||
return "#"
|
||||
if trimmed.startswith("\\\\"):
|
||||
return "#"
|
||||
if any(u.startswith(p) for p in _SAFE_LINK_PREFIXES):
|
||||
return url
|
||||
if ":" in u.split("/")[0]:
|
||||
|
||||
@@ -142,27 +142,35 @@ class TranslatorHandler:
|
||||
libretranslate_reachable = False
|
||||
|
||||
url = libretranslate_url or self.libretranslate_url
|
||||
explicit_override = (
|
||||
libretranslate_url is not None and str(libretranslate_url).strip() != ""
|
||||
)
|
||||
libre_base = None
|
||||
if self.has_requests:
|
||||
if libretranslate_url is not None and str(libretranslate_url).strip():
|
||||
try:
|
||||
url = normalize_loopback_http_service_base(libretranslate_url)
|
||||
except UnsafeOutboundUrlError as e:
|
||||
try:
|
||||
libre_base = normalize_loopback_http_service_base(url)
|
||||
except UnsafeOutboundUrlError as e:
|
||||
if explicit_override:
|
||||
msg = str(e)
|
||||
raise ValueError(msg) from e
|
||||
try:
|
||||
libretranslate_langs = _sync_run_coro(self._fetch_languages_async(url))
|
||||
if libretranslate_langs is not None:
|
||||
libretranslate_reachable = True
|
||||
languages.extend(
|
||||
{
|
||||
"code": lang.get("code"),
|
||||
"name": lang.get("name"),
|
||||
"source": "libretranslate",
|
||||
}
|
||||
for lang in libretranslate_langs
|
||||
libre_base = None
|
||||
if libre_base is not None:
|
||||
try:
|
||||
libretranslate_langs = _sync_run_coro(
|
||||
self._fetch_languages_async(libre_base),
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch LibreTranslate languages: {e}")
|
||||
if libretranslate_langs is not None:
|
||||
libretranslate_reachable = True
|
||||
languages.extend(
|
||||
{
|
||||
"code": lang.get("code"),
|
||||
"name": lang.get("name"),
|
||||
"source": "libretranslate",
|
||||
}
|
||||
for lang in libretranslate_langs
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch LibreTranslate languages: {e}")
|
||||
|
||||
if self.has_argos_lib:
|
||||
try:
|
||||
@@ -222,14 +230,13 @@ class TranslatorHandler:
|
||||
return self._translate_argos(text, source_lang, target_lang)
|
||||
|
||||
if self.translator_libretranslate_enabled and self.has_requests:
|
||||
url_raw = libretranslate_url or self.libretranslate_url
|
||||
try:
|
||||
url = normalize_loopback_http_service_base(url_raw)
|
||||
except UnsafeOutboundUrlError as e:
|
||||
msg = str(e)
|
||||
raise ValueError(msg) from e
|
||||
try:
|
||||
url = libretranslate_url or self.libretranslate_url
|
||||
if libretranslate_url is not None and str(libretranslate_url).strip():
|
||||
try:
|
||||
url = normalize_loopback_http_service_base(libretranslate_url)
|
||||
except UnsafeOutboundUrlError as e:
|
||||
msg = str(e)
|
||||
raise ValueError(msg) from e
|
||||
return self._translate_libretranslate(
|
||||
text,
|
||||
source_lang=source_lang,
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
# SPDX-License-Identifier: 0BSD
|
||||
|
||||
import threading
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from meshchatx.src.backend.community_interfaces_directory import (
|
||||
DEFAULT_SUBMITTED_URL,
|
||||
fetch_directory_payload,
|
||||
rows_from_payload,
|
||||
transform_directory_rows,
|
||||
validate_directory_fetch_url,
|
||||
)
|
||||
|
||||
|
||||
@@ -16,6 +23,116 @@ def test_default_url_is_submitted_online():
|
||||
assert "status=online" in DEFAULT_SUBMITTED_URL
|
||||
|
||||
|
||||
def test_validate_directory_fetch_url_accepts_default_host():
|
||||
u = "https://directory.rns.recipes/api/foo?bar=1"
|
||||
assert validate_directory_fetch_url(u) == u
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"bad",
|
||||
[
|
||||
"http://directory.rns.recipes/api",
|
||||
"https://127.0.0.1/",
|
||||
"https://metadata.internal/",
|
||||
"ftp://directory.rns.recipes/",
|
||||
"https://evil.com/https://directory.rns.recipes/",
|
||||
"https://user:pass@directory.rns.recipes/",
|
||||
"https://not-directory.rns.recipes.example/api",
|
||||
],
|
||||
)
|
||||
def test_validate_directory_fetch_url_rejects_ssrf(bad):
|
||||
with pytest.raises(ValueError):
|
||||
validate_directory_fetch_url(bad)
|
||||
|
||||
|
||||
class _Redirect302Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
self.send_response(302)
|
||||
self.send_header("Location", "http://127.0.0.1:9/")
|
||||
self.end_headers()
|
||||
|
||||
def log_message(self, *args):
|
||||
return
|
||||
|
||||
|
||||
class _Json200Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
body = b'{"data":[]}'
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "application/json")
|
||||
self.send_header("Content-Length", str(len(body)))
|
||||
self.end_headers()
|
||||
self.wfile.write(body)
|
||||
|
||||
def log_message(self, *args):
|
||||
return
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def redirect_http_port():
|
||||
srv = HTTPServer(("127.0.0.1", 0), _Redirect302Handler)
|
||||
thread = threading.Thread(target=srv.serve_forever, daemon=True)
|
||||
thread.start()
|
||||
port = srv.server_address[1]
|
||||
yield port
|
||||
srv.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_http_port():
|
||||
srv = HTTPServer(("127.0.0.1", 0), _Json200Handler)
|
||||
thread = threading.Thread(target=srv.serve_forever, daemon=True)
|
||||
thread.start()
|
||||
port = srv.server_address[1]
|
||||
yield port
|
||||
srv.shutdown()
|
||||
|
||||
|
||||
def test_directory_fetch_opener_blocks_http_redirect(redirect_http_port):
|
||||
from meshchatx.src.backend.community_interfaces_directory import (
|
||||
_DIRECTORY_FETCH_OPENER,
|
||||
)
|
||||
|
||||
req = urllib.request.Request(f"http://127.0.0.1:{redirect_http_port}/")
|
||||
with pytest.raises(urllib.error.HTTPError) as ei:
|
||||
_DIRECTORY_FETCH_OPENER.open(req, timeout=3)
|
||||
assert ei.value.code == 302
|
||||
|
||||
|
||||
def test_fetch_directory_payload_reads_json_when_no_redirect(
|
||||
monkeypatch,
|
||||
json_http_port,
|
||||
):
|
||||
import meshchatx.src.backend.community_interfaces_directory as cid
|
||||
|
||||
monkeypatch.setattr(
|
||||
cid,
|
||||
"validate_directory_fetch_url",
|
||||
lambda url: url,
|
||||
)
|
||||
out = fetch_directory_payload(
|
||||
f"http://127.0.0.1:{json_http_port}/x",
|
||||
timeout=3,
|
||||
)
|
||||
assert out == {"data": []}
|
||||
|
||||
|
||||
def test_fetch_directory_payload_raises_on_redirect(monkeypatch, redirect_http_port):
|
||||
import meshchatx.src.backend.community_interfaces_directory as cid
|
||||
|
||||
monkeypatch.setattr(
|
||||
cid,
|
||||
"validate_directory_fetch_url",
|
||||
lambda url: url,
|
||||
)
|
||||
with pytest.raises(urllib.error.HTTPError) as ei:
|
||||
fetch_directory_payload(
|
||||
f"http://127.0.0.1:{redirect_http_port}/",
|
||||
timeout=3,
|
||||
)
|
||||
assert ei.value.code == 302
|
||||
|
||||
|
||||
def test_rows_from_payload_dict_data():
|
||||
rows = rows_from_payload({"data": [{"id": 1}]})
|
||||
assert rows == [{"id": 1}]
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
# SPDX-License-Identifier: 0BSD
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from meshchatx.src.backend.database import Database
|
||||
from meshchatx.src.backend.database.legacy_migrator import LegacyMigrator
|
||||
from meshchatx.src.backend.database.provider import DatabaseProvider
|
||||
from meshchatx.src.backend.database.schema import DatabaseSchema
|
||||
|
||||
@@ -43,77 +41,6 @@ def test_database_initialization(temp_db):
|
||||
provider.close()
|
||||
|
||||
|
||||
def test_legacy_migrator_detection(temp_db):
|
||||
# Setup current DB
|
||||
provider = DatabaseProvider(temp_db)
|
||||
schema = DatabaseSchema(provider)
|
||||
schema.initialize()
|
||||
|
||||
# Setup a "legacy" DB in a temp directory
|
||||
with tempfile.TemporaryDirectory() as legacy_dir:
|
||||
identity_hash = "deadbeef"
|
||||
legacy_identity_dir = os.path.join(legacy_dir, "identities", identity_hash)
|
||||
os.makedirs(legacy_identity_dir)
|
||||
legacy_db_path = os.path.join(legacy_identity_dir, "database.db")
|
||||
|
||||
legacy_conn = sqlite3.connect(legacy_db_path)
|
||||
legacy_conn.execute("CREATE TABLE config (key TEXT, value TEXT)")
|
||||
legacy_conn.execute(
|
||||
"INSERT INTO config (key, value) VALUES ('display_name', 'Legacy User')",
|
||||
)
|
||||
legacy_conn.commit()
|
||||
legacy_conn.close()
|
||||
|
||||
migrator = LegacyMigrator(provider, legacy_dir, identity_hash)
|
||||
assert migrator.get_legacy_db_path() == legacy_db_path
|
||||
assert migrator.should_migrate() is True
|
||||
|
||||
provider.close()
|
||||
|
||||
|
||||
def test_legacy_migration_data(temp_db):
|
||||
provider = DatabaseProvider(temp_db)
|
||||
schema = DatabaseSchema(provider)
|
||||
schema.initialize()
|
||||
|
||||
with tempfile.TemporaryDirectory() as legacy_dir:
|
||||
identity_hash = "deadbeef"
|
||||
legacy_identity_dir = os.path.join(legacy_dir, "identities", identity_hash)
|
||||
os.makedirs(legacy_identity_dir)
|
||||
legacy_db_path = os.path.join(legacy_identity_dir, "database.db")
|
||||
|
||||
# Create legacy DB with some data
|
||||
legacy_conn = sqlite3.connect(legacy_db_path)
|
||||
legacy_conn.execute(
|
||||
"CREATE TABLE lxmf_messages (hash TEXT UNIQUE, content TEXT)",
|
||||
)
|
||||
legacy_conn.execute(
|
||||
"INSERT INTO lxmf_messages (hash, content) VALUES ('msg1', 'Hello Legacy')",
|
||||
)
|
||||
legacy_conn.execute("CREATE TABLE config (key TEXT UNIQUE, value TEXT)")
|
||||
legacy_conn.execute(
|
||||
"INSERT INTO config (key, value) VALUES ('test_key', 'test_val')",
|
||||
)
|
||||
legacy_conn.commit()
|
||||
legacy_conn.close()
|
||||
|
||||
migrator = LegacyMigrator(provider, legacy_dir, identity_hash)
|
||||
assert migrator.migrate() is True
|
||||
|
||||
# Verify data moved
|
||||
msg_row = provider.fetchone(
|
||||
"SELECT content FROM lxmf_messages WHERE hash = 'msg1'",
|
||||
)
|
||||
assert msg_row["content"] == "Hello Legacy"
|
||||
|
||||
config_row = provider.fetchone(
|
||||
"SELECT value FROM config WHERE key = 'test_key'",
|
||||
)
|
||||
assert config_row["value"] == "test_val"
|
||||
|
||||
provider.close()
|
||||
|
||||
|
||||
def test_database_health_snapshot_free_space(temp_db):
|
||||
db = Database(temp_db)
|
||||
db.initialize()
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
# SPDX-License-Identifier: 0BSD
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from meshchatx.src.backend.database import (
|
||||
Database,
|
||||
_sanitize_pragma_read_name,
|
||||
_sanitize_wal_checkpoint_mode,
|
||||
)
|
||||
from meshchatx.src.backend.database.provider import DatabaseProvider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db_path():
|
||||
fd, path = tempfile.mkstemp()
|
||||
os.close(fd)
|
||||
yield path
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
wal = path + "-wal"
|
||||
shm = path + "-shm"
|
||||
for p in (wal, shm):
|
||||
if os.path.exists(p):
|
||||
os.remove(p)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_db_singleton(temp_db_path):
|
||||
DatabaseProvider._instance = None
|
||||
yield
|
||||
DatabaseProvider._instance = None
|
||||
|
||||
|
||||
def test_sanitize_pragma_read_name_accepts_known_tokens():
|
||||
assert _sanitize_pragma_read_name("journal_mode") == "journal_mode"
|
||||
assert _sanitize_pragma_read_name(" page_size ") == "page_size"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"bad",
|
||||
[
|
||||
"",
|
||||
" ",
|
||||
"journal mode",
|
||||
"journal-mode",
|
||||
"x;detach",
|
||||
"x'y",
|
||||
"../../../x",
|
||||
"pragma(x)",
|
||||
],
|
||||
)
|
||||
def test_sanitize_pragma_read_name_rejects_injection(bad):
|
||||
assert _sanitize_pragma_read_name(bad) is None
|
||||
|
||||
|
||||
def test_sanitize_wal_checkpoint_mode_accepts_keywords():
|
||||
assert _sanitize_wal_checkpoint_mode("truncate") == "TRUNCATE"
|
||||
assert _sanitize_wal_checkpoint_mode("PASSIVE") == "PASSIVE"
|
||||
|
||||
|
||||
def test_sanitize_wal_checkpoint_mode_rejects_injection():
|
||||
with pytest.raises(ValueError):
|
||||
_sanitize_wal_checkpoint_mode("FULL);VACUUM")
|
||||
with pytest.raises(ValueError):
|
||||
_sanitize_wal_checkpoint_mode("bogus")
|
||||
|
||||
|
||||
def test_get_pragma_value_returns_default_for_malicious_name(temp_db_path):
|
||||
db = Database(temp_db_path)
|
||||
db.initialize()
|
||||
assert db._get_pragma_value("journal_mode;evil", "fallback") == "fallback"
|
||||
jm = db._get_pragma_value("journal_mode")
|
||||
assert jm is not None and str(jm).lower() == "wal"
|
||||
|
||||
|
||||
def test_checkpoint_wal_rejects_injected_mode(temp_db_path):
|
||||
db = Database(temp_db_path)
|
||||
db.initialize()
|
||||
with pytest.raises(ValueError):
|
||||
db._checkpoint_wal("TRUNCATE);ATTACH DATABASE 'x' AS z")
|
||||
|
||||
|
||||
def test_checkpoint_wal_passive_succeeds(temp_db_path):
|
||||
db = Database(temp_db_path)
|
||||
db.initialize()
|
||||
db._checkpoint_wal("PASSIVE")
|
||||
@@ -113,10 +113,6 @@ def test_emergency_mode_startup_logic(mock_rns, temp_dir):
|
||||
mock_integrity_instance = mock_integrity_class.return_value
|
||||
assert mock_integrity_instance.check_integrity.call_count == 0
|
||||
|
||||
# Verify migrate_from_legacy was NOT called
|
||||
mock_db_instance = mock_db_class.return_value
|
||||
assert mock_db_instance.migrate_from_legacy.call_count == 0
|
||||
|
||||
# Verify TelephoneManager.init_telephone was NOT called
|
||||
mock_tel_instance = mock_tel_class.return_value
|
||||
assert mock_tel_instance.init_telephone.call_count == 0
|
||||
@@ -215,10 +211,6 @@ def test_normal_mode_startup_logic(mock_rns, temp_dir):
|
||||
# Verify IntegrityManager.check_integrity WAS called
|
||||
assert mock_integrity_instance.check_integrity.call_count == 1
|
||||
|
||||
# Verify migrate_from_legacy WAS called
|
||||
mock_db_instance = mock_db_class.return_value
|
||||
assert mock_db_instance.migrate_from_legacy.call_count == 1
|
||||
|
||||
# Verify TelephoneManager.init_telephone WAS called
|
||||
mock_tel_instance = mock_tel_class.return_value
|
||||
assert mock_tel_instance.init_telephone.call_count == 1
|
||||
|
||||
@@ -118,6 +118,21 @@ class TestMarkdownRenderer(unittest.TestCase):
|
||||
self.assertNotIn("vbscript:", r)
|
||||
self.assertIn('href="#"', r)
|
||||
|
||||
def test_protocol_relative_link_href_neutralized(self):
|
||||
r = MarkdownRenderer.render("[phish](//evil.example/phish)")
|
||||
self.assertNotIn("//evil.example", r)
|
||||
self.assertNotIn('href="//', r)
|
||||
self.assertIn('href="#"', r)
|
||||
|
||||
def test_unc_link_href_neutralized(self):
|
||||
md = "[click](" + "\\\\\\\\evil.example\\\\share" + ")"
|
||||
r = MarkdownRenderer.render(md)
|
||||
self.assertIn('href="#"', r)
|
||||
|
||||
def test_protocol_relative_image_src_neutralized(self):
|
||||
r = MarkdownRenderer.render("")
|
||||
self.assertNotIn("//evil.example", r)
|
||||
|
||||
def test_safe_links_preserved(self):
|
||||
r = MarkdownRenderer.render("[link](https://example.com/path)")
|
||||
self.assertIn('href="https://example.com/path"', r)
|
||||
|
||||
@@ -142,7 +142,6 @@ def test_reticulum_meshchat_init(mock_rns, temp_dir):
|
||||
|
||||
# Verify database initialization
|
||||
mock_db_instance.initialize.assert_called_once()
|
||||
mock_db_instance.migrate_from_legacy.assert_called_once()
|
||||
|
||||
# Verify RNS initialization
|
||||
mock_rns["Reticulum"].assert_called_once_with(temp_dir)
|
||||
|
||||
@@ -9,11 +9,11 @@ from meshchatx.src.backend.translator_handler import TranslatorHandler
|
||||
|
||||
def test_translator_handler_init():
|
||||
handler = TranslatorHandler(
|
||||
libretranslate_url="http://test:5000",
|
||||
libretranslate_url="http://127.0.0.1:5000",
|
||||
translator_argos_enabled=True,
|
||||
translator_libretranslate_enabled=True,
|
||||
)
|
||||
assert handler.libretranslate_url == "http://test:5000"
|
||||
assert handler.libretranslate_url == "http://127.0.0.1:5000"
|
||||
assert handler.translator_argos_enabled is True
|
||||
|
||||
|
||||
@@ -130,6 +130,56 @@ def test_detect_language_libretranslate(mock_session_cls):
|
||||
assert result["source_lang"] == "en"
|
||||
|
||||
|
||||
@patch("meshchatx.src.backend.translator_handler.aiohttp.ClientSession")
|
||||
def test_libretranslate_post_disallows_redirects(mock_session_cls):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(
|
||||
return_value={
|
||||
"translatedText": "y",
|
||||
"detectedLanguage": {"language": "en"},
|
||||
},
|
||||
)
|
||||
mock_post_ctx = MagicMock()
|
||||
mock_post_ctx.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_post_ctx.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_session = MagicMock()
|
||||
mock_session.post = MagicMock(return_value=mock_post_ctx)
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_session_cls.return_value = mock_session
|
||||
|
||||
handler = TranslatorHandler(translator_libretranslate_enabled=True)
|
||||
handler.has_requests = True
|
||||
handler.translate_text("Hello", source_lang="en", target_lang="fr", use_argos=False)
|
||||
|
||||
assert mock_session.post.call_args.kwargs.get("allow_redirects") is False
|
||||
|
||||
|
||||
@patch("meshchatx.src.backend.translator_handler.aiohttp.ClientSession")
|
||||
def test_libretranslate_get_languages_disallows_redirects(mock_session_cls):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value=[{"code": "en", "name": "English"}])
|
||||
mock_get_ctx = MagicMock()
|
||||
mock_get_ctx.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_get_ctx.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_session = MagicMock()
|
||||
mock_session.get = MagicMock(return_value=mock_get_ctx)
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_session_cls.return_value = mock_session
|
||||
|
||||
handler = TranslatorHandler()
|
||||
handler.has_argos = False
|
||||
handler.has_argos_lib = False
|
||||
handler.has_argos_cli = False
|
||||
handler.has_requests = True
|
||||
handler.get_supported_languages()
|
||||
|
||||
assert mock_session.get.call_args.kwargs.get("allow_redirects") is False
|
||||
|
||||
|
||||
def test_translator_handler_errors():
|
||||
handler = TranslatorHandler(
|
||||
translator_argos_enabled=False,
|
||||
@@ -149,3 +199,40 @@ def test_language_code_to_name():
|
||||
|
||||
assert LANGUAGE_CODE_TO_NAME["en"] == "English"
|
||||
assert LANGUAGE_CODE_TO_NAME["de"] == "German"
|
||||
|
||||
|
||||
@patch("meshchatx.src.backend.translator_handler.aiohttp.ClientSession")
|
||||
def test_get_supported_languages_skips_non_loopback_stored_url(mock_session_cls):
|
||||
handler = TranslatorHandler(
|
||||
libretranslate_url="http://169.254.169.254/",
|
||||
translator_libretranslate_enabled=True,
|
||||
)
|
||||
handler.has_requests = True
|
||||
handler.has_argos = False
|
||||
handler.has_argos_lib = False
|
||||
handler.has_argos_cli = False
|
||||
assert handler.get_supported_languages() == []
|
||||
mock_session_cls.assert_not_called()
|
||||
|
||||
|
||||
def test_translate_text_rejects_non_loopback_libre_url():
|
||||
handler = TranslatorHandler(
|
||||
libretranslate_url="http://example.com:5000/",
|
||||
translator_libretranslate_enabled=True,
|
||||
translator_argos_enabled=False,
|
||||
)
|
||||
handler.has_requests = True
|
||||
with pytest.raises(ValueError, match="URL host must be"):
|
||||
handler.translate_text("Hello", "en", "de", use_argos=False)
|
||||
|
||||
|
||||
def test_get_translator_languages_response_explicit_bad_override_raises():
|
||||
handler = TranslatorHandler(
|
||||
libretranslate_url="http://127.0.0.1:5000",
|
||||
translator_libretranslate_enabled=True,
|
||||
)
|
||||
handler.has_requests = True
|
||||
with pytest.raises(ValueError, match="URL host must be"):
|
||||
handler.get_translator_languages_response(
|
||||
libretranslate_url="http://metadata.example/latest",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user