diff --git a/docker/complement/conf/mas_registration_shim.py b/docker/complement/conf/mas_registration_shim.py old mode 100644 new mode 100755 index 8d74a9767b..12217e9e02 --- a/docker/complement/conf/mas_registration_shim.py +++ b/docker/complement/conf/mas_registration_shim.py @@ -18,12 +18,11 @@ import hmac import json import os import secrets - import time import urllib.error import urllib.parse import urllib.request -from http.server import HTTPServer, BaseHTTPRequestHandler +from http.server import BaseHTTPRequestHandler, HTTPServer MAS_PORT = int(os.environ.get("MAS_PORT", "8081")) MAS_BASE = f"http://localhost:{MAS_PORT}" @@ -42,13 +41,18 @@ _nonces: dict[str, float] = {} def get_admin_token() -> str: """Obtain an admin access token from MAS via client_credentials grant using HTTP Basic Auth.""" import base64 - token_url = f"{MAS_BASE}/oauth2/token" - data = urllib.parse.urlencode({ - "grant_type": "client_credentials", - "scope": "urn:mas:admin", - }).encode() - creds = base64.b64encode(f"{MAS_ADMIN_CLIENT_ID}:{MAS_ADMIN_CLIENT_SECRET}".encode()).decode() + token_url = f"{MAS_BASE}/oauth2/token" + data = urllib.parse.urlencode( + { + "grant_type": "client_credentials", + "scope": "urn:mas:admin", + } + ).encode() + + creds = base64.b64encode( + f"{MAS_ADMIN_CLIENT_ID}:{MAS_ADMIN_CLIENT_SECRET}".encode() + ).decode() for attempt in range(30): try: @@ -58,12 +62,23 @@ def get_admin_token() -> str: with urllib.request.urlopen(req, timeout=5) as resp: body = json.loads(resp.read()) return body["access_token"] - except (urllib.error.URLError, urllib.error.HTTPError, KeyError, json.JSONDecodeError) as e: + except ( + urllib.error.URLError, + urllib.error.HTTPError, + KeyError, + json.JSONDecodeError, + ) as e: if attempt < 29: - print(f"[shim] Waiting for MAS token endpoint (attempt {attempt + 1}): {e}") + print( + f"[shim] Waiting for MAS token endpoint (attempt {attempt + 1}): {e}" + ) time.sleep(1) else: - raise RuntimeError(f"Failed to obtain admin token from MAS after 30 attempts: {e}") + raise RuntimeError( + f"Failed to obtain admin token from MAS after 30 attempts: {e}" + ) + else: + raise RuntimeError() def create_user_in_mas(admin_token: str, username: str) -> str: @@ -86,7 +101,7 @@ def set_password_in_mas(admin_token: str, user_id: str, password: str) -> None: req = urllib.request.Request(url, data=payload, method="POST") req.add_header("Content-Type", "application/json") req.add_header("Authorization", f"Bearer {admin_token}") - with urllib.request.urlopen(req, timeout=10) as resp: + with urllib.request.urlopen(req, timeout=10): pass # Expects 204 No Content @@ -97,18 +112,20 @@ def set_admin_in_mas(admin_token: str, user_id: str) -> None: req = urllib.request.Request(url, data=payload, method="POST") req.add_header("Content-Type", "application/json") req.add_header("Authorization", f"Bearer {admin_token}") - with urllib.request.urlopen(req, timeout=10) as resp: + with urllib.request.urlopen(req, timeout=10): pass # Expects 200 OK def login_via_mas(username: str, password: str, admin: bool = False) -> dict: """Login via MAS compat API to get an access token and device_id.""" url = f"{MAS_BASE}/_matrix/client/v3/login" - payload = json.dumps({ - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": username}, - "password": password, - }).encode() + payload = json.dumps( + { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": username}, + "password": password, + } + ).encode() req = urllib.request.Request(url, data=payload, method="POST") req.add_header("Content-Type", "application/json") with urllib.request.urlopen(req, timeout=10) as resp: @@ -133,7 +150,7 @@ def register_user(username: str, password: str, admin: bool = False) -> dict: class ShimHandler(BaseHTTPRequestHandler): """HTTP handler for the registration shim.""" - def log_message(self, fmt, *args): + def log_message(self, fmt: object, *args: object) -> None: print(f"[shim] {args[0]}") def _send_json(self, code: int, obj: dict) -> None: @@ -148,7 +165,7 @@ class ShimHandler(BaseHTTPRequestHandler): length = int(self.headers.get("Content-Length", 0)) return self.rfile.read(length) if length > 0 else b"" - def do_GET(self): + def do_GET(self) -> None: if self.path == "/health": self._send_json(200, {"ok": True}) return @@ -161,7 +178,7 @@ class ShimHandler(BaseHTTPRequestHandler): self._send_json(404, {"error": "Not found"}) - def do_POST(self): + def do_POST(self) -> None: try: if self.path == "/_synapse/admin/v1/register": self._handle_admin_register() @@ -227,11 +244,14 @@ class ShimHandler(BaseHTTPRequestHandler): password = data.get("password", "") if auth_type != "m.login.dummy": - self._send_json(401, { - "flows": [{"stages": ["m.login.dummy"]}], - "params": {}, - "session": secrets.token_hex(16), - }) + self._send_json( + 401, + { + "flows": [{"stages": ["m.login.dummy"]}], + "params": {}, + "session": secrets.token_hex(16), + }, + ) return if not username or not password: @@ -254,11 +274,15 @@ def _get_cached_admin_token() -> str: return _admin_token -def main(): +def main() -> None: print(f"[shim] Starting MAS Registration Shim on port {SHIM_PORT}") print(f"[shim] MAS endpoint: {MAS_BASE}") print(f"[shim] Server name: {SERVER_NAME}") - print(f"[shim] Admin client ID: {MAS_ADMIN_CLIENT_ID[:8]}..." if MAS_ADMIN_CLIENT_ID else "[shim] WARNING: No admin client ID configured") + print( + f"[shim] Admin client ID: {MAS_ADMIN_CLIENT_ID[:8]}..." + if MAS_ADMIN_CLIENT_ID + else "[shim] WARNING: No admin client ID configured" + ) server = HTTPServer(("0.0.0.0", SHIM_PORT), ShimHandler) print(f"[shim] Listening on port {SHIM_PORT}")