make cookies work for SSO tests; move request management into the aiohttp shim

This commit is contained in:
Matthew Hodgson
2026-03-22 22:24:12 +00:00
parent 7a16b3aea8
commit 21544b37ed
+106 -17
View File
@@ -103,8 +103,8 @@ class ShimRequestHeaders:
return value
return value.encode("utf-8")
def getRawHeaders(self, name: bytes | str) -> list[bytes] | None:
"""Return all values for *name* as a list of ``bytes``, or ``None``."""
def getRawHeaders(self, name: bytes | str, default: Any = None) -> list[bytes] | None:
"""Return all values for *name* as a list of ``bytes``, or *default*."""
str_name = self._norm_name(name)
values: list[str] = list(self._raw.getall(str_name, [])) if self._raw is not None else []
@@ -112,7 +112,7 @@ class ShimRequestHeaders:
values.extend(self._extra[str_name.lower()])
if not values:
return None
return default
return [self._to_bytes(v) for v in values]
def hasHeader(self, name: bytes | str) -> bool:
@@ -192,11 +192,11 @@ class ShimResponseHeaders:
self._original_name.setdefault(lower, str_name)
self._headers.setdefault(lower, []).append(self._norm_value(value))
def getRawHeaders(self, name: bytes | str) -> list[bytes] | list[str] | None:
def getRawHeaders(self, name: bytes | str, default: Any = None) -> list[bytes] | list[str] | None:
lower = self._norm_name(name).lower()
vals = self._headers.get(lower)
if vals is None:
return None
return default
# Return str if called with str, bytes if called with bytes
# (matching Twisted's Headers behavior)
if isinstance(name, str):
@@ -470,28 +470,57 @@ class SynapseRequest:
cls,
channel: Any,
site: "SynapseSite",
our_server_name: str = "test_server",
method: bytes = b"GET",
uri: bytes = b"/",
body: bytes = b"",
client_ip: str = "127.0.0.1",
our_server_name: str = "",
max_request_body_size: int = 50 * 1024 * 1024,
) -> "SynapseRequest":
"""Create a SynapseRequest for testing without a real aiohttp request.
The caller should set `.content`, `.method`, `.path`, `.uri`, `.args`,
and call `.requestHeaders.addRawHeader()` as needed.
Handles all the setup that Twisted's Request did automatically:
- Splits URI into path and query string
- Parses query string into ``args``
- Sets up request headers shim
After construction, callers can add headers via
``req.requestHeaders.addRawHeader()`` and must call
``req.prepare(content_is_form=...)`` to finalise args from
form-encoded bodies before dispatching.
"""
from urllib.parse import parse_qs, urlparse
obj = object.__new__(cls)
obj._aiohttp_request = None
obj.synapse_site = site
obj.reactor = getattr(site, 'reactor', None)
obj.our_server_name = our_server_name
obj.our_server_name = our_server_name or getattr(site, 'server_name', 'test')
# Request properties — to be populated by test code
obj.method = b"GET"
obj.path = b"/"
obj.uri = b"/"
# --- Request properties ---
obj.method = method
obj.uri = uri # full URI including query string
# Split URI into path (no query) and query args
uri_str = uri.decode("utf-8") if isinstance(uri, bytes) else uri
parsed = urlparse(uri_str)
obj.path = parsed.path.encode("utf-8")
obj.clientproto = b"HTTP/1.1"
obj.args = {}
obj.content = io.BytesIO(b"")
# Parse query string into args (bytes-keyed, like Twisted)
obj.args: dict[bytes, list[bytes]] = {}
if parsed.query:
for k, vs in parse_qs(parsed.query, keep_blank_values=True).items():
bk = k.encode("utf-8")
obj.args[bk] = [v.encode("utf-8") for v in vs]
obj.content = io.BytesIO(body)
obj._client_ip = client_ip
# --- Headers ---
obj.requestHeaders = ShimRequestHeaders(None)
# --- Response buffering ---
obj.responseHeaders = ShimResponseHeaders()
obj.code = 200
obj.code_message = b"OK"
@@ -500,9 +529,12 @@ class SynapseRequest:
obj.sentLength = 0
obj._response_buffer = bytearray()
obj._disconnected = False
# --- X-Forwarded ---
obj._forwarded_for = None
obj._forwarded_https = False
# --- Synapse-specific ---
global _next_request_seq
obj.request_seq = _next_request_seq
_next_request_seq += 1
@@ -518,11 +550,61 @@ class SynapseRequest:
obj._is_processing = False
obj._processing_finished_time = None
obj.finish_time = None
obj.cookies = []
obj.cookies: list[bytes] = []
obj.channel = channel
obj._client_ip = "127.0.0.1"
return obj
def prepare_for_dispatch(self, content_is_form: bool = False) -> None:
"""Finalise request state before dispatching to a handler.
Parses form-encoded POST bodies into ``self.args`` (matching Twisted's
automatic behaviour) and sets up metrics / logging context.
"""
from urllib.parse import parse_qs
# Merge form-encoded body into args
if content_is_form:
body = self.content.read()
self.content.seek(0)
if body:
body_str = body.decode("utf-8") if isinstance(body, bytes) else body
for k, vs in parse_qs(body_str, keep_blank_values=True).items():
bk = k.encode("utf-8")
self.args.setdefault(bk, []).extend(
v.encode("utf-8") for v in vs
)
# Set up metrics and logging context
self.start_render(self.__class__.__name__)
def dispatch(self, root_resource: Any) -> None:
"""Resolve the target resource and dispatch this request.
Walks the resource tree to find the handler, then kicks off
the async render as an ``asyncio.Task``.
"""
target = _resolve_resource(root_resource, self.path)
if hasattr(target, '_async_render_wrapper'):
self.render_deferred = asyncio.ensure_future(
target._async_render_wrapper(self)
)
else:
# Simple resource with render_GET/render_POST etc.
method_str = self.method.decode('ascii') if isinstance(self.method, bytes) else self.method
handler = getattr(target, 'render_' + method_str, None)
if handler:
result = handler(self)
if asyncio.iscoroutine(result):
self.render_deferred = asyncio.ensure_future(result)
else:
from synapse.http.server import respond_with_json
respond_with_json(
self, 404,
{"errcode": "M_UNRECOGNIZED", "error": "Unrecognized request"},
send_cors=True,
)
# ------------------------------------------------------------------
# X-Forwarded-For processing
# ------------------------------------------------------------------
@@ -778,6 +860,13 @@ class SynapseRequest:
aiohttp shim it merely sets the ``finished`` flag — the actual
``aiohttp.web.Response`` is built later by ``build_aiohttp_response()``.
"""
# Flush any cookies from self.cookies into Set-Cookie headers.
# Twisted's Request.finish() did this automatically.
for cookie_bytes in self.cookies:
cookie_str = cookie_bytes.decode("utf-8") if isinstance(cookie_bytes, bytes) else cookie_bytes
self.responseHeaders.addRawHeader("Set-Cookie", cookie_str)
self.cookies.clear()
self.finish_time = time.time()
self.finished = True