diff --git a/synapse/http/aiohttp_shim.py b/synapse/http/aiohttp_shim.py index 584dcf41bf..e5093c2f63 100644 --- a/synapse/http/aiohttp_shim.py +++ b/synapse/http/aiohttp_shim.py @@ -103,8 +103,8 @@ class ShimRequestHeaders: return value return value.encode("utf-8") - def getRawHeaders(self, name: bytes | str) -> list[bytes] | None: - """Return all values for *name* as a list of ``bytes``, or ``None``.""" + def getRawHeaders(self, name: bytes | str, default: Any = None) -> list[bytes] | None: + """Return all values for *name* as a list of ``bytes``, or *default*.""" str_name = self._norm_name(name) values: list[str] = list(self._raw.getall(str_name, [])) if self._raw is not None else [] @@ -112,7 +112,7 @@ class ShimRequestHeaders: values.extend(self._extra[str_name.lower()]) if not values: - return None + return default return [self._to_bytes(v) for v in values] def hasHeader(self, name: bytes | str) -> bool: @@ -192,11 +192,11 @@ class ShimResponseHeaders: self._original_name.setdefault(lower, str_name) self._headers.setdefault(lower, []).append(self._norm_value(value)) - def getRawHeaders(self, name: bytes | str) -> list[bytes] | list[str] | None: + def getRawHeaders(self, name: bytes | str, default: Any = None) -> list[bytes] | list[str] | None: lower = self._norm_name(name).lower() vals = self._headers.get(lower) if vals is None: - return None + return default # Return str if called with str, bytes if called with bytes # (matching Twisted's Headers behavior) if isinstance(name, str): @@ -470,28 +470,57 @@ class SynapseRequest: cls, channel: Any, site: "SynapseSite", - our_server_name: str = "test_server", + method: bytes = b"GET", + uri: bytes = b"/", + body: bytes = b"", + client_ip: str = "127.0.0.1", + our_server_name: str = "", max_request_body_size: int = 50 * 1024 * 1024, ) -> "SynapseRequest": """Create a SynapseRequest for testing without a real aiohttp request. - The caller should set `.content`, `.method`, `.path`, `.uri`, `.args`, - and call `.requestHeaders.addRawHeader()` as needed. + Handles all the setup that Twisted's Request did automatically: + - Splits URI into path and query string + - Parses query string into ``args`` + - Sets up request headers shim + + After construction, callers can add headers via + ``req.requestHeaders.addRawHeader()`` and must call + ``req.prepare(content_is_form=...)`` to finalise args from + form-encoded bodies before dispatching. """ + from urllib.parse import parse_qs, urlparse + obj = object.__new__(cls) obj._aiohttp_request = None obj.synapse_site = site obj.reactor = getattr(site, 'reactor', None) - obj.our_server_name = our_server_name + obj.our_server_name = our_server_name or getattr(site, 'server_name', 'test') - # Request properties — to be populated by test code - obj.method = b"GET" - obj.path = b"/" - obj.uri = b"/" + # --- Request properties --- + obj.method = method + obj.uri = uri # full URI including query string + + # Split URI into path (no query) and query args + uri_str = uri.decode("utf-8") if isinstance(uri, bytes) else uri + parsed = urlparse(uri_str) + obj.path = parsed.path.encode("utf-8") obj.clientproto = b"HTTP/1.1" - obj.args = {} - obj.content = io.BytesIO(b"") + + # Parse query string into args (bytes-keyed, like Twisted) + obj.args: dict[bytes, list[bytes]] = {} + if parsed.query: + for k, vs in parse_qs(parsed.query, keep_blank_values=True).items(): + bk = k.encode("utf-8") + obj.args[bk] = [v.encode("utf-8") for v in vs] + + obj.content = io.BytesIO(body) + obj._client_ip = client_ip + + # --- Headers --- obj.requestHeaders = ShimRequestHeaders(None) + + # --- Response buffering --- obj.responseHeaders = ShimResponseHeaders() obj.code = 200 obj.code_message = b"OK" @@ -500,9 +529,12 @@ class SynapseRequest: obj.sentLength = 0 obj._response_buffer = bytearray() obj._disconnected = False + + # --- X-Forwarded --- obj._forwarded_for = None obj._forwarded_https = False + # --- Synapse-specific --- global _next_request_seq obj.request_seq = _next_request_seq _next_request_seq += 1 @@ -518,11 +550,61 @@ class SynapseRequest: obj._is_processing = False obj._processing_finished_time = None obj.finish_time = None - obj.cookies = [] + obj.cookies: list[bytes] = [] obj.channel = channel - obj._client_ip = "127.0.0.1" return obj + def prepare_for_dispatch(self, content_is_form: bool = False) -> None: + """Finalise request state before dispatching to a handler. + + Parses form-encoded POST bodies into ``self.args`` (matching Twisted's + automatic behaviour) and sets up metrics / logging context. + """ + from urllib.parse import parse_qs + + # Merge form-encoded body into args + if content_is_form: + body = self.content.read() + self.content.seek(0) + if body: + body_str = body.decode("utf-8") if isinstance(body, bytes) else body + for k, vs in parse_qs(body_str, keep_blank_values=True).items(): + bk = k.encode("utf-8") + self.args.setdefault(bk, []).extend( + v.encode("utf-8") for v in vs + ) + + # Set up metrics and logging context + self.start_render(self.__class__.__name__) + + def dispatch(self, root_resource: Any) -> None: + """Resolve the target resource and dispatch this request. + + Walks the resource tree to find the handler, then kicks off + the async render as an ``asyncio.Task``. + """ + target = _resolve_resource(root_resource, self.path) + + if hasattr(target, '_async_render_wrapper'): + self.render_deferred = asyncio.ensure_future( + target._async_render_wrapper(self) + ) + else: + # Simple resource with render_GET/render_POST etc. + method_str = self.method.decode('ascii') if isinstance(self.method, bytes) else self.method + handler = getattr(target, 'render_' + method_str, None) + if handler: + result = handler(self) + if asyncio.iscoroutine(result): + self.render_deferred = asyncio.ensure_future(result) + else: + from synapse.http.server import respond_with_json + respond_with_json( + self, 404, + {"errcode": "M_UNRECOGNIZED", "error": "Unrecognized request"}, + send_cors=True, + ) + # ------------------------------------------------------------------ # X-Forwarded-For processing # ------------------------------------------------------------------ @@ -778,6 +860,13 @@ class SynapseRequest: aiohttp shim it merely sets the ``finished`` flag — the actual ``aiohttp.web.Response`` is built later by ``build_aiohttp_response()``. """ + # Flush any cookies from self.cookies into Set-Cookie headers. + # Twisted's Request.finish() did this automatically. + for cookie_bytes in self.cookies: + cookie_str = cookie_bytes.decode("utf-8") if isinstance(cookie_bytes, bytes) else cookie_bytes + self.responseHeaders.addRawHeader("Set-Cookie", cookie_str) + self.cookies.clear() + self.finish_time = time.time() self.finished = True