refactor(tools): improve error handling and validation in RN path trace and RN probe handlers

This commit is contained in:
Ivan
2026-05-02 06:47:08 -05:00
parent e5cb2c7678
commit 894e63c4c3
7 changed files with 175 additions and 17 deletions
+25 -10
View File
@@ -9425,22 +9425,37 @@ class ReticulumMeshChat:
)
result = await self.rnpath_trace_handler.trace_path(destination_hash)
return web.json_response(result)
except Exception as e:
import traceback
error_msg = f"Trace route failed: {e}\n{traceback.format_exc()}"
print(error_msg)
return web.json_response({"error": error_msg}, status=500)
except Exception:
logger.exception("RN path trace route failed")
return web.json_response({"error": "Trace failed"}, status=500)
@routes.post("/api/v1/rnprobe")
async def rnprobe(request):
data = await request.json()
destination_hash_str = data.get("destination_hash", "")
full_name = data.get("full_name", "")
size = int(data.get("size", RNProbeHandler.DEFAULT_PROBE_SIZE))
timeout = float(data.get("timeout", 0)) or None
wait = float(data.get("wait", 0))
probes = int(data.get("probes", 1))
try:
size = int(data.get("size", RNProbeHandler.DEFAULT_PROBE_SIZE))
except (TypeError, ValueError):
return web.json_response({"message": "Invalid size"}, status=400)
try:
wait = float(data.get("wait", 0))
except (TypeError, ValueError):
return web.json_response({"message": "Invalid wait"}, status=400)
try:
probes = int(data.get("probes", 1))
except (TypeError, ValueError):
return web.json_response({"message": "Invalid probes"}, status=400)
timeout = None
raw_timeout = data.get("timeout", 0)
if raw_timeout is not None:
try:
t = float(raw_timeout)
except (TypeError, ValueError):
return web.json_response({"message": "Invalid timeout"}, status=400)
if t != 0:
timeout = t
try:
destination_hash = bytes.fromhex(destination_hash_str)
+3
View File
@@ -37,6 +37,7 @@ class RNCPHandler:
def teardown_receive_destination(self):
if self.receive_destination is None:
self.allowed_identity_hashes = []
return
dest = self.receive_destination
self.receive_destination = None
@@ -45,6 +46,7 @@ class RNCPHandler:
dest.deregister_request_handler("fetch_file")
self._listener_fetch_registered = False
self._listener_fetch_allowed = False
self.allowed_identity_hashes = []
with contextlib.suppress(Exception):
RNS.Transport.deregister_destination(dest)
@@ -79,6 +81,7 @@ class RNCPHandler:
):
self.teardown_receive_destination()
self.allowed_identity_hashes = []
if allowed_hashes:
self.allowed_identity_hashes = [
bytes.fromhex(h) if isinstance(h, str) else h for h in allowed_hashes
+39 -4
View File
@@ -22,6 +22,28 @@ RNGIT_IDX_REPOSITORY = 0x00
_MAX_PROBE_NAMES = 64
_MAX_REFS_PREVIEW_CHARS = 8000
_MAX_RNGIT_PATH_TIMEOUT_S = 600.0
_MAX_RNGIT_LINK_TIMEOUT_S = 180.0
_MAX_RNGIT_LIST_TIMEOUT_S = 600.0
def _clamp_timeout(
value: float | None,
*,
default: float,
minimum: float,
maximum: float,
) -> float:
if value is None:
return default
try:
v = float(value)
except (TypeError, ValueError):
return default
if v < minimum:
return minimum
return min(v, maximum)
def display_name_from_rngit_app_data(app_data_b64: str | None) -> str | None:
"""Decode optional rngit announce app_data (often a short node label)."""
@@ -102,11 +124,24 @@ async def list_remote_git_refs(
if not repo:
return {"ok": False, "error": "invalid_repository_name"}
path_timeout = (
path_timeout if path_timeout is not None else RNS.Transport.PATH_REQUEST_TIMEOUT
path_timeout = _clamp_timeout(
path_timeout,
default=float(RNS.Transport.PATH_REQUEST_TIMEOUT),
minimum=1.0,
maximum=_MAX_RNGIT_PATH_TIMEOUT_S,
)
link_timeout = _clamp_timeout(
link_timeout,
default=30.0,
minimum=1.0,
maximum=_MAX_RNGIT_LINK_TIMEOUT_S,
)
list_timeout = _clamp_timeout(
list_timeout,
default=120.0,
minimum=1.0,
maximum=_MAX_RNGIT_LIST_TIMEOUT_S,
)
link_timeout = link_timeout if link_timeout is not None else 30.0
list_timeout = list_timeout if list_timeout is not None else 120.0
destination_hash = bytes.fromhex(norm_hash)
@@ -1,11 +1,13 @@
# SPDX-License-Identifier: 0BSD
import asyncio
import logging
import time
import traceback
import RNS
logger = logging.getLogger(__name__)
class RNPathTraceHandler:
def __init__(self, reticulum_instance, identity):
@@ -98,5 +100,6 @@ class RNPathTraceHandler:
"interface": next_hop_interface,
"next_hop": next_hop_bytes.hex() if next_hop_bytes else None,
}
except Exception as e:
return {"error": f"Trace failed: {e}\n{traceback.format_exc()}"}
except Exception:
logger.exception("RN path trace failed")
return {"error": "Trace failed"}
+37
View File
@@ -1,20 +1,55 @@
# SPDX-License-Identifier: 0BSD
import asyncio
import logging
import os
import time
import RNS
logger = logging.getLogger(__name__)
class RNProbeHandler:
DEFAULT_PROBE_SIZE = 16
DEFAULT_TIMEOUT = 12
MAX_PROBES = 50
MAX_WAIT_S = 120.0
MAX_TIMEOUT_S = 600.0
_PACKET_OVERHEAD_BYTES = 96
def __init__(self, reticulum_instance, identity):
self.reticulum = reticulum_instance
self.identity = identity
def _max_payload_bytes(self) -> int:
mtu = int(RNS.Reticulum.MTU)
return max(1, mtu - self._PACKET_OVERHEAD_BYTES)
def _validate_probe_params(
self,
size: int,
timeout: float | None,
wait: float,
probes: int,
) -> None:
if probes < 1 or probes > self.MAX_PROBES:
msg = f"probes must be between 1 and {self.MAX_PROBES}"
raise ValueError(msg)
if wait < 0 or wait > self.MAX_WAIT_S:
msg = f"wait must be between 0 and {self.MAX_WAIT_S} seconds"
raise ValueError(msg)
max_payload = self._max_payload_bytes()
if size < 1 or size > max_payload:
msg = f"size must be between 1 and {max_payload} bytes for this MTU"
raise ValueError(msg)
if timeout is not None:
if timeout <= 0 or timeout > self.MAX_TIMEOUT_S:
msg = (
f"timeout must be positive and at most {self.MAX_TIMEOUT_S} seconds"
)
raise ValueError(msg)
async def probe_destination(
self,
destination_hash: bytes,
@@ -24,6 +59,8 @@ class RNProbeHandler:
wait: float = 0,
probes: int = 1,
):
self._validate_probe_params(size, timeout, wait, probes)
try:
app_name, aspects = RNS.Destination.app_and_aspects_from_name(full_name)
except Exception as e:
@@ -153,3 +153,28 @@ def test_setup_receive_destination_idempotent_restarts_listener(
assert mock_transport.deregister_destination.call_count == 1
assert mock_transport.deregister_destination.call_args[0][0] is first_dest
assert rncp_handler.receive_destination is second_dest
@patch("meshchatx.src.backend.rncp_handler.RNS.Transport")
@patch("meshchatx.src.backend.rncp_handler.RNS.Identity")
@patch("meshchatx.src.backend.rncp_handler.RNS.Destination")
@patch("meshchatx.src.backend.rncp_handler.RNS.Reticulum")
def test_setup_receive_destination_empty_allowlist_clears_previous(
mock_rns_reticulum,
mock_dest_class,
mock_identity_class,
mock_transport,
rncp_handler,
):
mock_rns_reticulum.identitypath = "/tmp/rns/identities"
mock_id_obj = MagicMock()
mock_identity_class.from_file.return_value = mock_id_obj
mock_dest_obj = MagicMock()
mock_dest_obj.hash = b"\x04" * 16
mock_dest_class.return_value = mock_dest_obj
with patch("os.path.isfile", return_value=True):
rncp_handler.setup_receive_destination(allowed_hashes=["cd" * 16])
assert len(rncp_handler.allowed_identity_hashes) == 1
rncp_handler.setup_receive_destination(allowed_hashes=[])
assert rncp_handler.allowed_identity_hashes == []
@@ -0,0 +1,40 @@
# SPDX-License-Identifier: 0BSD
from unittest.mock import MagicMock, patch
import pytest
import RNS
from meshchatx.src.backend.rngit_tool import _clamp_timeout
from meshchatx.src.backend.rnprobe_handler import RNProbeHandler
@patch.object(RNS.Reticulum, "MTU", 500)
def test_rnprobe_rejects_negative_probes():
h = RNProbeHandler(MagicMock(), MagicMock())
with pytest.raises(ValueError, match="probes"):
h._validate_probe_params(16, None, 0, 0)
with pytest.raises(ValueError, match="probes"):
h._validate_probe_params(16, None, 0, 51)
@patch.object(RNS.Reticulum, "MTU", 500)
def test_rnprobe_rejects_oversized_payload():
h = RNProbeHandler(MagicMock(), MagicMock())
max_b = h._max_payload_bytes()
with pytest.raises(ValueError, match="size"):
h._validate_probe_params(max_b + 1, None, 0, 1)
@patch.object(RNS.Reticulum, "MTU", 500)
def test_rnprobe_accepts_max_payload():
h = RNProbeHandler(MagicMock(), MagicMock())
max_b = h._max_payload_bytes()
h._validate_probe_params(max_b, None, 0, 1)
def test_rngit_clamp_timeout():
assert _clamp_timeout(None, default=30.0, minimum=1.0, maximum=100.0) == 30.0
assert _clamp_timeout(5000.0, default=30.0, minimum=1.0, maximum=100.0) == 100.0
assert _clamp_timeout(0.5, default=30.0, minimum=1.0, maximum=100.0) == 1.0
assert _clamp_timeout(50.0, default=30.0, minimum=1.0, maximum=100.0) == 50.0