mirror of
https://github.com/element-hq/synapse.git
synced 2026-05-17 18:35:46 +00:00
make cookies work for SSO tests; move request management into the aiohttp shim
This commit is contained in:
+106
-17
@@ -103,8 +103,8 @@ class ShimRequestHeaders:
|
||||
return value
|
||||
return value.encode("utf-8")
|
||||
|
||||
def getRawHeaders(self, name: bytes | str) -> list[bytes] | None:
|
||||
"""Return all values for *name* as a list of ``bytes``, or ``None``."""
|
||||
def getRawHeaders(self, name: bytes | str, default: Any = None) -> list[bytes] | None:
|
||||
"""Return all values for *name* as a list of ``bytes``, or *default*."""
|
||||
str_name = self._norm_name(name)
|
||||
|
||||
values: list[str] = list(self._raw.getall(str_name, [])) if self._raw is not None else []
|
||||
@@ -112,7 +112,7 @@ class ShimRequestHeaders:
|
||||
values.extend(self._extra[str_name.lower()])
|
||||
|
||||
if not values:
|
||||
return None
|
||||
return default
|
||||
return [self._to_bytes(v) for v in values]
|
||||
|
||||
def hasHeader(self, name: bytes | str) -> bool:
|
||||
@@ -192,11 +192,11 @@ class ShimResponseHeaders:
|
||||
self._original_name.setdefault(lower, str_name)
|
||||
self._headers.setdefault(lower, []).append(self._norm_value(value))
|
||||
|
||||
def getRawHeaders(self, name: bytes | str) -> list[bytes] | list[str] | None:
|
||||
def getRawHeaders(self, name: bytes | str, default: Any = None) -> list[bytes] | list[str] | None:
|
||||
lower = self._norm_name(name).lower()
|
||||
vals = self._headers.get(lower)
|
||||
if vals is None:
|
||||
return None
|
||||
return default
|
||||
# Return str if called with str, bytes if called with bytes
|
||||
# (matching Twisted's Headers behavior)
|
||||
if isinstance(name, str):
|
||||
@@ -470,28 +470,57 @@ class SynapseRequest:
|
||||
cls,
|
||||
channel: Any,
|
||||
site: "SynapseSite",
|
||||
our_server_name: str = "test_server",
|
||||
method: bytes = b"GET",
|
||||
uri: bytes = b"/",
|
||||
body: bytes = b"",
|
||||
client_ip: str = "127.0.0.1",
|
||||
our_server_name: str = "",
|
||||
max_request_body_size: int = 50 * 1024 * 1024,
|
||||
) -> "SynapseRequest":
|
||||
"""Create a SynapseRequest for testing without a real aiohttp request.
|
||||
|
||||
The caller should set `.content`, `.method`, `.path`, `.uri`, `.args`,
|
||||
and call `.requestHeaders.addRawHeader()` as needed.
|
||||
Handles all the setup that Twisted's Request did automatically:
|
||||
- Splits URI into path and query string
|
||||
- Parses query string into ``args``
|
||||
- Sets up request headers shim
|
||||
|
||||
After construction, callers can add headers via
|
||||
``req.requestHeaders.addRawHeader()`` and must call
|
||||
``req.prepare(content_is_form=...)`` to finalise args from
|
||||
form-encoded bodies before dispatching.
|
||||
"""
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
obj = object.__new__(cls)
|
||||
obj._aiohttp_request = None
|
||||
obj.synapse_site = site
|
||||
obj.reactor = getattr(site, 'reactor', None)
|
||||
obj.our_server_name = our_server_name
|
||||
obj.our_server_name = our_server_name or getattr(site, 'server_name', 'test')
|
||||
|
||||
# Request properties — to be populated by test code
|
||||
obj.method = b"GET"
|
||||
obj.path = b"/"
|
||||
obj.uri = b"/"
|
||||
# --- Request properties ---
|
||||
obj.method = method
|
||||
obj.uri = uri # full URI including query string
|
||||
|
||||
# Split URI into path (no query) and query args
|
||||
uri_str = uri.decode("utf-8") if isinstance(uri, bytes) else uri
|
||||
parsed = urlparse(uri_str)
|
||||
obj.path = parsed.path.encode("utf-8")
|
||||
obj.clientproto = b"HTTP/1.1"
|
||||
obj.args = {}
|
||||
obj.content = io.BytesIO(b"")
|
||||
|
||||
# Parse query string into args (bytes-keyed, like Twisted)
|
||||
obj.args: dict[bytes, list[bytes]] = {}
|
||||
if parsed.query:
|
||||
for k, vs in parse_qs(parsed.query, keep_blank_values=True).items():
|
||||
bk = k.encode("utf-8")
|
||||
obj.args[bk] = [v.encode("utf-8") for v in vs]
|
||||
|
||||
obj.content = io.BytesIO(body)
|
||||
obj._client_ip = client_ip
|
||||
|
||||
# --- Headers ---
|
||||
obj.requestHeaders = ShimRequestHeaders(None)
|
||||
|
||||
# --- Response buffering ---
|
||||
obj.responseHeaders = ShimResponseHeaders()
|
||||
obj.code = 200
|
||||
obj.code_message = b"OK"
|
||||
@@ -500,9 +529,12 @@ class SynapseRequest:
|
||||
obj.sentLength = 0
|
||||
obj._response_buffer = bytearray()
|
||||
obj._disconnected = False
|
||||
|
||||
# --- X-Forwarded ---
|
||||
obj._forwarded_for = None
|
||||
obj._forwarded_https = False
|
||||
|
||||
# --- Synapse-specific ---
|
||||
global _next_request_seq
|
||||
obj.request_seq = _next_request_seq
|
||||
_next_request_seq += 1
|
||||
@@ -518,11 +550,61 @@ class SynapseRequest:
|
||||
obj._is_processing = False
|
||||
obj._processing_finished_time = None
|
||||
obj.finish_time = None
|
||||
obj.cookies = []
|
||||
obj.cookies: list[bytes] = []
|
||||
obj.channel = channel
|
||||
obj._client_ip = "127.0.0.1"
|
||||
return obj
|
||||
|
||||
def prepare_for_dispatch(self, content_is_form: bool = False) -> None:
|
||||
"""Finalise request state before dispatching to a handler.
|
||||
|
||||
Parses form-encoded POST bodies into ``self.args`` (matching Twisted's
|
||||
automatic behaviour) and sets up metrics / logging context.
|
||||
"""
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
# Merge form-encoded body into args
|
||||
if content_is_form:
|
||||
body = self.content.read()
|
||||
self.content.seek(0)
|
||||
if body:
|
||||
body_str = body.decode("utf-8") if isinstance(body, bytes) else body
|
||||
for k, vs in parse_qs(body_str, keep_blank_values=True).items():
|
||||
bk = k.encode("utf-8")
|
||||
self.args.setdefault(bk, []).extend(
|
||||
v.encode("utf-8") for v in vs
|
||||
)
|
||||
|
||||
# Set up metrics and logging context
|
||||
self.start_render(self.__class__.__name__)
|
||||
|
||||
def dispatch(self, root_resource: Any) -> None:
|
||||
"""Resolve the target resource and dispatch this request.
|
||||
|
||||
Walks the resource tree to find the handler, then kicks off
|
||||
the async render as an ``asyncio.Task``.
|
||||
"""
|
||||
target = _resolve_resource(root_resource, self.path)
|
||||
|
||||
if hasattr(target, '_async_render_wrapper'):
|
||||
self.render_deferred = asyncio.ensure_future(
|
||||
target._async_render_wrapper(self)
|
||||
)
|
||||
else:
|
||||
# Simple resource with render_GET/render_POST etc.
|
||||
method_str = self.method.decode('ascii') if isinstance(self.method, bytes) else self.method
|
||||
handler = getattr(target, 'render_' + method_str, None)
|
||||
if handler:
|
||||
result = handler(self)
|
||||
if asyncio.iscoroutine(result):
|
||||
self.render_deferred = asyncio.ensure_future(result)
|
||||
else:
|
||||
from synapse.http.server import respond_with_json
|
||||
respond_with_json(
|
||||
self, 404,
|
||||
{"errcode": "M_UNRECOGNIZED", "error": "Unrecognized request"},
|
||||
send_cors=True,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# X-Forwarded-For processing
|
||||
# ------------------------------------------------------------------
|
||||
@@ -778,6 +860,13 @@ class SynapseRequest:
|
||||
aiohttp shim it merely sets the ``finished`` flag — the actual
|
||||
``aiohttp.web.Response`` is built later by ``build_aiohttp_response()``.
|
||||
"""
|
||||
# Flush any cookies from self.cookies into Set-Cookie headers.
|
||||
# Twisted's Request.finish() did this automatically.
|
||||
for cookie_bytes in self.cookies:
|
||||
cookie_str = cookie_bytes.decode("utf-8") if isinstance(cookie_bytes, bytes) else cookie_bytes
|
||||
self.responseHeaders.addRawHeader("Set-Cookie", cookie_str)
|
||||
self.cookies.clear()
|
||||
|
||||
self.finish_time = time.time()
|
||||
self.finished = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user