fixup! Add MAS registration shim for legacy Matrix APIs

This commit is contained in:
Olivier 'reivilibre
2026-04-27 18:35:38 +01:00
parent d9956b735b
commit 9fce7bac9b
+52 -28
View File
@@ -18,12 +18,11 @@ import hmac
import json
import os
import secrets
import time
import urllib.error
import urllib.parse
import urllib.request
from http.server import HTTPServer, BaseHTTPRequestHandler
from http.server import BaseHTTPRequestHandler, HTTPServer
MAS_PORT = int(os.environ.get("MAS_PORT", "8081"))
MAS_BASE = f"http://localhost:{MAS_PORT}"
@@ -42,13 +41,18 @@ _nonces: dict[str, float] = {}
def get_admin_token() -> str:
"""Obtain an admin access token from MAS via client_credentials grant using HTTP Basic Auth."""
import base64
token_url = f"{MAS_BASE}/oauth2/token"
data = urllib.parse.urlencode({
"grant_type": "client_credentials",
"scope": "urn:mas:admin",
}).encode()
creds = base64.b64encode(f"{MAS_ADMIN_CLIENT_ID}:{MAS_ADMIN_CLIENT_SECRET}".encode()).decode()
token_url = f"{MAS_BASE}/oauth2/token"
data = urllib.parse.urlencode(
{
"grant_type": "client_credentials",
"scope": "urn:mas:admin",
}
).encode()
creds = base64.b64encode(
f"{MAS_ADMIN_CLIENT_ID}:{MAS_ADMIN_CLIENT_SECRET}".encode()
).decode()
for attempt in range(30):
try:
@@ -58,12 +62,23 @@ def get_admin_token() -> str:
with urllib.request.urlopen(req, timeout=5) as resp:
body = json.loads(resp.read())
return body["access_token"]
except (urllib.error.URLError, urllib.error.HTTPError, KeyError, json.JSONDecodeError) as e:
except (
urllib.error.URLError,
urllib.error.HTTPError,
KeyError,
json.JSONDecodeError,
) as e:
if attempt < 29:
print(f"[shim] Waiting for MAS token endpoint (attempt {attempt + 1}): {e}")
print(
f"[shim] Waiting for MAS token endpoint (attempt {attempt + 1}): {e}"
)
time.sleep(1)
else:
raise RuntimeError(f"Failed to obtain admin token from MAS after 30 attempts: {e}")
raise RuntimeError(
f"Failed to obtain admin token from MAS after 30 attempts: {e}"
)
else:
raise RuntimeError()
def create_user_in_mas(admin_token: str, username: str) -> str:
@@ -86,7 +101,7 @@ def set_password_in_mas(admin_token: str, user_id: str, password: str) -> None:
req = urllib.request.Request(url, data=payload, method="POST")
req.add_header("Content-Type", "application/json")
req.add_header("Authorization", f"Bearer {admin_token}")
with urllib.request.urlopen(req, timeout=10) as resp:
with urllib.request.urlopen(req, timeout=10):
pass # Expects 204 No Content
@@ -97,18 +112,20 @@ def set_admin_in_mas(admin_token: str, user_id: str) -> None:
req = urllib.request.Request(url, data=payload, method="POST")
req.add_header("Content-Type", "application/json")
req.add_header("Authorization", f"Bearer {admin_token}")
with urllib.request.urlopen(req, timeout=10) as resp:
with urllib.request.urlopen(req, timeout=10):
pass # Expects 200 OK
def login_via_mas(username: str, password: str, admin: bool = False) -> dict:
"""Login via MAS compat API to get an access token and device_id."""
url = f"{MAS_BASE}/_matrix/client/v3/login"
payload = json.dumps({
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": username},
"password": password,
}).encode()
payload = json.dumps(
{
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": username},
"password": password,
}
).encode()
req = urllib.request.Request(url, data=payload, method="POST")
req.add_header("Content-Type", "application/json")
with urllib.request.urlopen(req, timeout=10) as resp:
@@ -133,7 +150,7 @@ def register_user(username: str, password: str, admin: bool = False) -> dict:
class ShimHandler(BaseHTTPRequestHandler):
"""HTTP handler for the registration shim."""
def log_message(self, fmt, *args):
def log_message(self, fmt: object, *args: object) -> None:
print(f"[shim] {args[0]}")
def _send_json(self, code: int, obj: dict) -> None:
@@ -148,7 +165,7 @@ class ShimHandler(BaseHTTPRequestHandler):
length = int(self.headers.get("Content-Length", 0))
return self.rfile.read(length) if length > 0 else b""
def do_GET(self):
def do_GET(self) -> None:
if self.path == "/health":
self._send_json(200, {"ok": True})
return
@@ -161,7 +178,7 @@ class ShimHandler(BaseHTTPRequestHandler):
self._send_json(404, {"error": "Not found"})
def do_POST(self):
def do_POST(self) -> None:
try:
if self.path == "/_synapse/admin/v1/register":
self._handle_admin_register()
@@ -227,11 +244,14 @@ class ShimHandler(BaseHTTPRequestHandler):
password = data.get("password", "")
if auth_type != "m.login.dummy":
self._send_json(401, {
"flows": [{"stages": ["m.login.dummy"]}],
"params": {},
"session": secrets.token_hex(16),
})
self._send_json(
401,
{
"flows": [{"stages": ["m.login.dummy"]}],
"params": {},
"session": secrets.token_hex(16),
},
)
return
if not username or not password:
@@ -254,11 +274,15 @@ def _get_cached_admin_token() -> str:
return _admin_token
def main():
def main() -> None:
print(f"[shim] Starting MAS Registration Shim on port {SHIM_PORT}")
print(f"[shim] MAS endpoint: {MAS_BASE}")
print(f"[shim] Server name: {SERVER_NAME}")
print(f"[shim] Admin client ID: {MAS_ADMIN_CLIENT_ID[:8]}..." if MAS_ADMIN_CLIENT_ID else "[shim] WARNING: No admin client ID configured")
print(
f"[shim] Admin client ID: {MAS_ADMIN_CLIENT_ID[:8]}..."
if MAS_ADMIN_CLIENT_ID
else "[shim] WARNING: No admin client ID configured"
)
server = HTTPServer(("0.0.0.0", SHIM_PORT), ShimHandler)
print(f"[shim] Listening on port {SHIM_PORT}")