make cookies work for SSO tests; move request management into the aiohttp shim

This commit is contained in:
Matthew Hodgson
2026-03-22 22:28:54 +00:00
parent 21544b37ed
commit 4f9e8b82bc
2 changed files with 19 additions and 71 deletions

View File

@@ -482,36 +482,21 @@ def make_request(
channel = FakeChannel(site, reactor, ip=client_ip, clock=clock)
# Use the shim's for_testing constructor
# Create the request via the shim — it handles path/query splitting,
# query arg parsing, etc.
from synapse.http.aiohttp_shim import SynapseRequest as ShimRequest
req = ShimRequest.for_testing(
channel,
site,
our_server_name="test_server",
method=method,
uri=path,
body=content,
client_ip=client_ip,
max_request_body_size=MAX_REQUEST_SIZE,
)
channel.request = req
req.method = method
# URI is the full path+query; path is just the path part (no query string).
# Twisted's Request.path was always without query string.
req.uri = path
from urllib.parse import parse_qs, urlparse
path_str = path.decode("utf-8") if isinstance(path, bytes) else path
parsed = urlparse(path_str)
req.path = parsed.path.encode("utf-8") if isinstance(path, bytes) else parsed.path.encode("utf-8")
req.content = BytesIO(content)
req.content.seek(0)
req._client_ip = client_ip
# Parse query string into args
if parsed.query:
for k, vs in parse_qs(parsed.query, keep_blank_values=True).items():
bk = k.encode("utf-8") if isinstance(k, str) else k
req.args[bk] = [v.encode("utf-8") if isinstance(v, str) else v for v in vs]
# Add standard headers
# Add standard headers (test-specific, not part of the shim)
if custom_headers is None or not any(
(k if isinstance(k, bytes) else k.encode("ascii")) == b"Content-Length"
for k, _ in custom_headers
@@ -545,58 +530,13 @@ def make_request(
for k, v in custom_headers:
req.requestHeaders.addRawHeader(k, v)
# Initialize request metrics and logcontext before dispatch
import asyncio
from synapse.http.request_metrics import RequestMetrics
from synapse.logging.context import ContextRequest, LoggingContext
req.start_time = time.time()
server_name = getattr(site, 'server_name', 'test')
req.request_metrics = RequestMetrics(our_server_name=server_name)
req.request_metrics.start(req.start_time, name="test", method=req.get_method())
# Create a ContextRequest (NOT the SynapseRequest itself!) for the LoggingContext
context_request = ContextRequest(
request_id=req.get_request_id(),
ip_address=req.getClientIP(),
site_tag=getattr(site, 'site_tag', 'test'),
requester=None,
authenticated_entity=None,
method=req.get_method(),
url=req.get_redacted_uri(),
protocol="HTTP/1.1",
user_agent="",
)
req.logcontext = LoggingContext(
name="test-%s-%s" % (req.get_method(), req.get_redacted_uri()),
server_name=server_name,
request=context_request,
)
# Dispatch the request through the resource tree using the same
# logic as the production aiohttp handler.
from synapse.http.aiohttp_shim import _resolve_resource
# Finalise request state (form body parsing, metrics, logcontext)
# and dispatch through the resource tree.
req.prepare_for_dispatch(content_is_form=content_is_form)
root_resource = getattr(site, 'resource', None) or getattr(site, '_resource', None)
if root_resource is not None:
target = _resolve_resource(root_resource, path)
if hasattr(target, '_async_render_wrapper'):
req.render_deferred = asyncio.ensure_future(
target._async_render_wrapper(req)
)
else:
# Simple resource with render_GET/render_POST etc
method_str = req.method.decode('ascii') if isinstance(req.method, bytes) else req.method
method_name = 'render_' + method_str
handler = getattr(target, method_name, None)
if handler:
result = handler(req)
if asyncio.iscoroutine(result):
req.render_deferred = asyncio.ensure_future(result)
else:
from synapse.http.server import respond_with_json
respond_with_json(req, 404, {"errcode": "M_UNRECOGNIZED", "error": "Unrecognized request"}, send_cors=True)
req.dispatch(root_resource)
if await_result:
channel.await_result()

View File

@@ -368,6 +368,14 @@ class TestCase(_stdlib_unittest.TestCase):
f"(difference: {abs(first - second)!r})"
)
def assertSubstring(self, substring: str, text: str) -> None:
"""Assert that substring is found in text.
Replacement for twisted.trial.unittest.TestCase.assertSubstring.
"""
if substring not in text:
self.fail(f"{substring!r} not found in {text!r}")
def assertNoResult(self, d: "Deferred[Any]") -> None:
"""Assert that a Deferred has not yet fired.