mirror of
https://github.com/element-hq/synapse.git
synced 2026-03-29 15:20:16 +00:00
make cookies work for SSO tests; move request management into the aiohttp shim
This commit is contained in:
@@ -482,36 +482,21 @@ def make_request(
|
||||
|
||||
channel = FakeChannel(site, reactor, ip=client_ip, clock=clock)
|
||||
|
||||
# Use the shim's for_testing constructor
|
||||
# Create the request via the shim — it handles path/query splitting,
|
||||
# query arg parsing, etc.
|
||||
from synapse.http.aiohttp_shim import SynapseRequest as ShimRequest
|
||||
req = ShimRequest.for_testing(
|
||||
channel,
|
||||
site,
|
||||
our_server_name="test_server",
|
||||
method=method,
|
||||
uri=path,
|
||||
body=content,
|
||||
client_ip=client_ip,
|
||||
max_request_body_size=MAX_REQUEST_SIZE,
|
||||
)
|
||||
channel.request = req
|
||||
|
||||
req.method = method
|
||||
# URI is the full path+query; path is just the path part (no query string).
|
||||
# Twisted's Request.path was always without query string.
|
||||
req.uri = path
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
path_str = path.decode("utf-8") if isinstance(path, bytes) else path
|
||||
parsed = urlparse(path_str)
|
||||
req.path = parsed.path.encode("utf-8") if isinstance(path, bytes) else parsed.path.encode("utf-8")
|
||||
|
||||
req.content = BytesIO(content)
|
||||
req.content.seek(0)
|
||||
req._client_ip = client_ip
|
||||
|
||||
# Parse query string into args
|
||||
if parsed.query:
|
||||
for k, vs in parse_qs(parsed.query, keep_blank_values=True).items():
|
||||
bk = k.encode("utf-8") if isinstance(k, str) else k
|
||||
req.args[bk] = [v.encode("utf-8") if isinstance(v, str) else v for v in vs]
|
||||
|
||||
# Add standard headers
|
||||
# Add standard headers (test-specific, not part of the shim)
|
||||
if custom_headers is None or not any(
|
||||
(k if isinstance(k, bytes) else k.encode("ascii")) == b"Content-Length"
|
||||
for k, _ in custom_headers
|
||||
@@ -545,58 +530,13 @@ def make_request(
|
||||
for k, v in custom_headers:
|
||||
req.requestHeaders.addRawHeader(k, v)
|
||||
|
||||
# Initialize request metrics and logcontext before dispatch
|
||||
import asyncio
|
||||
from synapse.http.request_metrics import RequestMetrics
|
||||
from synapse.logging.context import ContextRequest, LoggingContext
|
||||
|
||||
req.start_time = time.time()
|
||||
server_name = getattr(site, 'server_name', 'test')
|
||||
req.request_metrics = RequestMetrics(our_server_name=server_name)
|
||||
req.request_metrics.start(req.start_time, name="test", method=req.get_method())
|
||||
|
||||
# Create a ContextRequest (NOT the SynapseRequest itself!) for the LoggingContext
|
||||
context_request = ContextRequest(
|
||||
request_id=req.get_request_id(),
|
||||
ip_address=req.getClientIP(),
|
||||
site_tag=getattr(site, 'site_tag', 'test'),
|
||||
requester=None,
|
||||
authenticated_entity=None,
|
||||
method=req.get_method(),
|
||||
url=req.get_redacted_uri(),
|
||||
protocol="HTTP/1.1",
|
||||
user_agent="",
|
||||
)
|
||||
req.logcontext = LoggingContext(
|
||||
name="test-%s-%s" % (req.get_method(), req.get_redacted_uri()),
|
||||
server_name=server_name,
|
||||
request=context_request,
|
||||
)
|
||||
|
||||
# Dispatch the request through the resource tree using the same
|
||||
# logic as the production aiohttp handler.
|
||||
from synapse.http.aiohttp_shim import _resolve_resource
|
||||
# Finalise request state (form body parsing, metrics, logcontext)
|
||||
# and dispatch through the resource tree.
|
||||
req.prepare_for_dispatch(content_is_form=content_is_form)
|
||||
|
||||
root_resource = getattr(site, 'resource', None) or getattr(site, '_resource', None)
|
||||
if root_resource is not None:
|
||||
target = _resolve_resource(root_resource, path)
|
||||
|
||||
if hasattr(target, '_async_render_wrapper'):
|
||||
req.render_deferred = asyncio.ensure_future(
|
||||
target._async_render_wrapper(req)
|
||||
)
|
||||
else:
|
||||
# Simple resource with render_GET/render_POST etc
|
||||
method_str = req.method.decode('ascii') if isinstance(req.method, bytes) else req.method
|
||||
method_name = 'render_' + method_str
|
||||
handler = getattr(target, method_name, None)
|
||||
if handler:
|
||||
result = handler(req)
|
||||
if asyncio.iscoroutine(result):
|
||||
req.render_deferred = asyncio.ensure_future(result)
|
||||
else:
|
||||
from synapse.http.server import respond_with_json
|
||||
respond_with_json(req, 404, {"errcode": "M_UNRECOGNIZED", "error": "Unrecognized request"}, send_cors=True)
|
||||
req.dispatch(root_resource)
|
||||
|
||||
if await_result:
|
||||
channel.await_result()
|
||||
|
||||
@@ -368,6 +368,14 @@ class TestCase(_stdlib_unittest.TestCase):
|
||||
f"(difference: {abs(first - second)!r})"
|
||||
)
|
||||
|
||||
def assertSubstring(self, substring: str, text: str) -> None:
|
||||
"""Assert that substring is found in text.
|
||||
|
||||
Replacement for twisted.trial.unittest.TestCase.assertSubstring.
|
||||
"""
|
||||
if substring not in text:
|
||||
self.fail(f"{substring!r} not found in {text!r}")
|
||||
|
||||
def assertNoResult(self, d: "Deferred[Any]") -> None:
|
||||
"""Assert that a Deferred has not yet fired.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user