fix: validate request bodies and improve error handling in meshchat API endpoints

This commit is contained in:
Ivan
2026-04-09 04:49:52 -05:00
parent bb910f288b
commit 68c8e6e363
11 changed files with 120 additions and 40 deletions
+65 -16
View File
@@ -2813,6 +2813,11 @@ class ReticulumMeshChat:
{"error": "Invalid JSON body"},
status=400,
)
if not isinstance(data, dict):
return web.json_response(
{"error": "Invalid request body"},
status=400,
)
password = data.get("password")
if not password or len(password) < 8:
@@ -2838,6 +2843,8 @@ class ReticulumMeshChat:
self.config.auth_password_hash.set(password_hash)
session = await get_session(request)
session.invalidate()
session = await get_session(request)
session["authenticated"] = True
session["identity_hash"] = self.identity.hash.hex()
@@ -2885,6 +2892,11 @@ class ReticulumMeshChat:
{"error": "Invalid JSON body"},
status=400,
)
if not isinstance(data, dict):
return web.json_response(
{"error": "Invalid request body"},
status=400,
)
password = data.get("password")
password_hash = self.config.auth_password_hash.get()
@@ -2924,6 +2936,8 @@ class ReticulumMeshChat:
password.encode("utf-8"),
password_hash.encode("utf-8"),
):
session = await get_session(request)
session.invalidate()
session = await get_session(request)
session["authenticated"] = True
session["identity_hash"] = self.identity.hash.hex()
@@ -3983,7 +3997,7 @@ class ReticulumMeshChat:
if is_connected_to_shared_instance:
# Try to find the shared instance address from active connections
try:
for conn in process.connections(kind="all"):
for conn in process.net_connections(kind="all"):
if conn.status == psutil.CONN_ESTABLISHED and conn.raddr:
# Check for common Reticulum shared instance ports or UNIX sockets
if (
@@ -6524,8 +6538,14 @@ class ReticulumMeshChat:
async def get_all_archived_pages(request):
# get search query and pagination from request
query = request.query.get("q", "").strip()
page = int(request.query.get("page", 1))
limit = int(request.query.get("limit", 15))
try:
page = max(1, int(request.query.get("page", 1)))
except (ValueError, TypeError):
page = 1
try:
limit = max(1, min(100, int(request.query.get("limit", 15))))
except (ValueError, TypeError):
limit = 15
offset = (page - 1) * limit
# fetch archived pages from database
@@ -8000,10 +8020,10 @@ class ReticulumMeshChat:
},
)
except Exception as e:
except Exception:
return web.json_response(
{
"message": f"Sending Failed: {e!s}",
"message": "Sending failed",
},
status=503,
)
@@ -8101,7 +8121,7 @@ class ReticulumMeshChat:
self.message_handler.get_conversation_messages,
local_hash,
destination_hash,
limit=int(count) if count else 100,
limit=min(int(count), 1000) if count else 100,
after_id=after_id if order == "asc" else None,
before_id=after_id if order == "desc" else None,
)
@@ -8138,7 +8158,10 @@ class ReticulumMeshChat:
# handle image
if attachment_type == "image" and "image" in fields:
image_data = base64.b64decode(fields["image"]["image_bytes"])
allowed_image_types = {"png", "jpeg", "jpg", "gif", "webp", "bmp"}
image_type = fields["image"]["image_type"]
if image_type.lower() not in allowed_image_types:
image_type = "png"
return web.Response(body=image_data, content_type=f"image/{image_type}")
# handle audio
@@ -8154,13 +8177,24 @@ class ReticulumMeshChat:
if file_index is not None:
try:
index = int(file_index)
if index < 0:
return web.json_response(
{"message": "Invalid file index"}, status=400,
)
file_attachment = fields["file_attachments"][index]
file_data = base64.b64decode(file_attachment["file_bytes"])
safe_name = (
os.path.basename(file_attachment["file_name"])
.replace('"', "_")
.replace("\r", "")
.replace("\n", "")
.replace("\x00", "")
) or "download"
return web.Response(
body=file_data,
content_type="application/octet-stream",
headers={
"Content-Disposition": f'attachment; filename="{file_attachment["file_name"]}"',
"Content-Disposition": f'attachment; filename="{safe_name}"',
},
)
except (ValueError, IndexError):
@@ -8460,7 +8494,7 @@ class ReticulumMeshChat:
return web.json_response({"message": "Folders and mappings imported"})
# mark lxmf conversation as read
@routes.get("/api/v1/lxmf/conversations/{destination_hash}/mark-as-read")
@routes.post("/api/v1/lxmf/conversations/{destination_hash}/mark-as-read")
async def lxmf_conversations_mark_read(request):
# get path params
destination_hash = request.match_info.get("destination_hash", "")
@@ -8665,10 +8699,9 @@ class ReticulumMeshChat:
)
except Exception as e:
RNS.log(f"Error in notifications_get: {e}", RNS.LOG_ERROR)
import traceback
traceback.print_exc()
return web.json_response({"error": str(e)}, status=500)
return web.json_response(
{"error": "Internal error"}, status=500,
)
# get blocked destinations
@routes.get("/api/v1/blocked-destinations")
@@ -8933,7 +8966,12 @@ class ReticulumMeshChat:
return web.json_response({"message": "Offline map disabled"})
mbtiles_dir = self.map_manager.get_mbtiles_dir()
file_path = os.path.join(mbtiles_dir, filename)
safe_name = os.path.basename(filename)
file_path = os.path.join(mbtiles_dir, safe_name)
resolved = os.path.realpath(file_path)
base = os.path.realpath(mbtiles_dir)
if not resolved.startswith(base + os.sep):
return web.json_response({"error": "Invalid filename"}, status=400)
if os.path.exists(file_path):
self.map_manager.close()
self.config.map_offline_path.set(file_path)
@@ -9086,19 +9124,24 @@ class ReticulumMeshChat:
if field.name != "file":
return web.json_response({"error": "No file field"}, status=400)
filename = field.filename
filename = os.path.basename(field.filename or "")
if not filename.endswith(".mbtiles"):
return web.json_response(
{"error": "Invalid file format, must be .mbtiles"},
status=400,
)
# save to mbtiles dir
mbtiles_dir = self.map_manager.get_mbtiles_dir()
if not os.path.exists(mbtiles_dir):
os.makedirs(mbtiles_dir)
dest_path = os.path.join(mbtiles_dir, filename)
resolved = os.path.realpath(dest_path)
base = os.path.realpath(mbtiles_dir)
if not resolved.startswith(base + os.sep):
return web.json_response(
{"error": "Invalid filename"}, status=400,
)
size = 0
with open(dest_path, "wb") as f:
@@ -9587,7 +9630,13 @@ class ReticulumMeshChat:
if path.endswith("/"):
path += "index.html"
local_path = os.path.join(dm.docs_dir, path)
try:
local_path = os.path.realpath(os.path.join(dm.docs_dir, path))
base = os.path.realpath(dm.docs_dir)
except (ValueError, OSError):
return web.json_response({"error": "Invalid path"}, status=400)
if not local_path.startswith(base + os.sep) and local_path != base:
return web.json_response({"error": "Invalid path"}, status=400)
if os.path.exists(local_path) and os.path.isfile(local_path):
return web.FileResponse(local_path)
+23 -4
View File
@@ -303,7 +303,13 @@ class DocsManager:
return sorted(docs, key=lambda x: x["name"])
def get_doc_content(self, path):
full_path = os.path.join(self.meshchatx_docs_dir, path)
try:
full_path = os.path.realpath(os.path.join(self.meshchatx_docs_dir, path))
base = os.path.realpath(self.meshchatx_docs_dir)
except (ValueError, OSError):
return None
if not full_path.startswith(base + os.sep) and full_path != base:
return None
if not os.path.exists(full_path):
return None
@@ -597,8 +603,16 @@ class DocsManager:
return False
def _extract_docs(self, zip_path, version):
# Target dir for this version
version_dir = os.path.join(self.versions_dir, version)
safe_version = os.path.basename(version)
if not safe_version or safe_version in (".", ".."):
raise ValueError(f"Invalid version name: {version}")
version_dir = os.path.join(self.versions_dir, safe_version)
resolved = os.path.realpath(version_dir)
base = os.path.realpath(self.versions_dir)
if not resolved.startswith(base + os.sep):
raise ValueError(f"Invalid version name: {version}")
if os.path.exists(version_dir):
shutil.rmtree(version_dir)
os.makedirs(version_dir)
@@ -623,6 +637,8 @@ class DocsManager:
if has_docs_subfolder:
members_to_extract = [m for m in namelist if m.startswith(docs_prefix)]
for member in members_to_extract:
if ".." in member.split("/"):
continue
zip_ref.extract(member, temp_extract)
src_path = os.path.join(temp_extract, root_folder, "docs")
@@ -635,7 +651,10 @@ class DocsManager:
else:
shutil.copy2(s, d)
else:
zip_ref.extractall(temp_extract)
safe_members = [
m for m in namelist if ".." not in m.split("/")
]
zip_ref.extractall(temp_extract, members=safe_members)
src_path = os.path.join(temp_extract, root_folder)
if os.path.exists(src_path) and os.path.isdir(src_path):
for item in os.listdir(src_path):
+6 -1
View File
@@ -76,7 +76,12 @@ class MapManager:
def delete_mbtiles(self, filename):
mbtiles_dir = self.get_mbtiles_dir()
file_path = os.path.join(mbtiles_dir, filename)
safe_name = os.path.basename(filename)
file_path = os.path.join(mbtiles_dir, safe_name)
resolved = os.path.realpath(file_path)
base = os.path.realpath(mbtiles_dir)
if not resolved.startswith(base + os.sep):
return False
if os.path.exists(file_path) and file_path.endswith(".mbtiles"):
if file_path == self.get_offline_path():
self.config.map_offline_path.set(None)
+10 -10
View File
@@ -13,17 +13,17 @@ def convert_nomadnet_string_data_to_map(path_data: str | None):
def convert_nomadnet_field_data_to_map(field_data):
if field_data is None:
return None
data = {}
if field_data is not None or "{}":
try:
json_data = field_data
if isinstance(json_data, dict):
data = {f"field_{key}": value for key, value in json_data.items()}
else:
return None
except Exception as e:
print(f"skipping invalid field data: {e}")
try:
if isinstance(field_data, dict):
data = {f"field_{key}": value for key, value in field_data.items()}
else:
return None
except Exception as e:
print(f"skipping invalid field data: {e}")
return None
return data
+3 -2
View File
@@ -149,8 +149,9 @@ class RNCPHandler:
if self.fetch_jail:
if data.startswith(self.fetch_jail + "/"):
data = data.replace(self.fetch_jail + "/", "")
file_path = os.path.abspath(os.path.expanduser(f"{self.fetch_jail}/{data}"))
if not file_path.startswith(self.fetch_jail + "/"):
file_path = os.path.realpath(os.path.expanduser(f"{self.fetch_jail}/{data}"))
jail_real = os.path.realpath(self.fetch_jail)
if not file_path.startswith(jail_real + "/"):
return self.REQ_FETCH_NOT_ALLOWED
else:
file_path = os.path.abspath(os.path.expanduser(data))
@@ -325,7 +325,7 @@ export default {
for (const conversation of conversations) {
if (conversation.is_unread) {
try {
await window.api.get(
await window.api.post(
`/api/v1/lxmf/conversations/${conversation.destination_hash}/mark-as-read`
);
} catch (e) {
@@ -3986,7 +3986,7 @@ export default {
// mark conversation as read on server
try {
await window.api.get(`/api/v1/lxmf/conversations/${conversation.destination_hash}/mark-as-read`);
await window.api.post(`/api/v1/lxmf/conversations/${conversation.destination_hash}/mark-as-read`);
} catch (e) {
// do nothing if failed to mark as read
console.log(e);
+6 -2
View File
@@ -66,11 +66,15 @@ def generate_ssl_certificate(cert_path: str, key_path: str):
with open(cert_path, "wb") as f:
f.write(cert.public_bytes(serialization.Encoding.PEM))
with open(key_path, "wb") as f:
f.write(
key_fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
try:
os.write(
key_fd,
private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
),
)
finally:
os.close(key_fd)
+1 -1
View File
@@ -321,7 +321,7 @@
"path": "/api/v1/lxmf/conversations/move-to-folder"
},
{
"method": "GET",
"method": "POST",
"path": "/api/v1/lxmf/conversations/{destination_hash}/mark-as-read"
},
{
+3 -1
View File
@@ -37,7 +37,9 @@ from meshchatx.src.backend.message_handler import MessageHandler
# Strings that are valid for most text columns but include adversarial chars
st_nasty_text = st.text(
alphabet=st.characters(whitelist_categories=("L", "N", "P", "S", "Z", "C")),
alphabet=st.characters(
whitelist_categories=("L", "N", "P", "S", "Z", "Cc", "Cf", "Cn", "Co"),
),
min_size=0,
max_size=300,
)
+1 -1
View File
@@ -539,7 +539,7 @@ describe("NotificationBell clear all", () => {
await wrapper.vm.clearAllNotifications();
await new Promise((r) => setTimeout(r, 100));
const readCalls = global.api.get.mock.calls.filter((c) => c[0]?.includes("/mark-as-read"));
const readCalls = global.api.post.mock.calls.filter((c) => c[0]?.includes("/mark-as-read"));
expect(readCalls.length).toBe(1);
expect(readCalls[0][0]).toContain("conv1");