mirror of
https://github.com/element-hq/synapse.git
synced 2026-03-30 23:45:43 +00:00
convert login tests to asyncio and make them work with faketime
This commit is contained in:
@@ -206,11 +206,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_POST_ratelimiting_per_address(self) -> None:
|
||||
async def test_POST_ratelimiting_per_address(self) -> None:
|
||||
# Create different users so we're sure not to be bothered by the per-user
|
||||
# ratelimiter.
|
||||
for i in range(6):
|
||||
self.register_user("kermit" + str(i), "monkey")
|
||||
await self.register_user("kermit" + str(i), "monkey")
|
||||
|
||||
for i in range(6):
|
||||
params = {
|
||||
@@ -218,7 +218,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
|
||||
"password": "monkey",
|
||||
}
|
||||
channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
channel = await self.make_request(b"POST", LOGIN_URL, params)
|
||||
|
||||
if i == 5:
|
||||
self.assertEqual(channel.code, 429, msg=channel.result)
|
||||
@@ -240,7 +240,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
|
||||
"password": "monkey",
|
||||
}
|
||||
channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
channel = await self.make_request(b"POST", LOGIN_URL, params)
|
||||
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
|
||||
@@ -257,8 +257,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_POST_ratelimiting_per_account(self) -> None:
|
||||
self.register_user("kermit", "monkey")
|
||||
async def test_POST_ratelimiting_per_account(self) -> None:
|
||||
await self.register_user("kermit", "monkey")
|
||||
|
||||
for i in range(6):
|
||||
params = {
|
||||
@@ -266,7 +266,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"identifier": {"type": "m.id.user", "user": "kermit"},
|
||||
"password": "monkey",
|
||||
}
|
||||
channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
channel = await self.make_request(b"POST", LOGIN_URL, params)
|
||||
|
||||
if i == 5:
|
||||
self.assertEqual(channel.code, 429, msg=channel.result)
|
||||
@@ -288,7 +288,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"identifier": {"type": "m.id.user", "user": "kermit"},
|
||||
"password": "monkey",
|
||||
}
|
||||
channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
channel = await self.make_request(b"POST", LOGIN_URL, params)
|
||||
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
|
||||
@@ -305,8 +305,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_POST_ratelimiting_per_account_failed_attempts(self) -> None:
|
||||
self.register_user("kermit", "monkey")
|
||||
async def test_POST_ratelimiting_per_account_failed_attempts(self) -> None:
|
||||
await self.register_user("kermit", "monkey")
|
||||
|
||||
for i in range(6):
|
||||
params = {
|
||||
@@ -314,7 +314,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"identifier": {"type": "m.id.user", "user": "kermit"},
|
||||
"password": "notamonkey",
|
||||
}
|
||||
channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
channel = await self.make_request(b"POST", LOGIN_URL, params)
|
||||
|
||||
if i == 5:
|
||||
self.assertEqual(channel.code, 429, msg=channel.result)
|
||||
@@ -336,7 +336,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"identifier": {"type": "m.id.user", "user": "kermit"},
|
||||
"password": "notamonkey",
|
||||
}
|
||||
channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
channel = await self.make_request(b"POST", LOGIN_URL, params)
|
||||
|
||||
self.assertEqual(channel.code, 403, msg=channel.result)
|
||||
|
||||
@@ -430,57 +430,57 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
@override_config({"session_lifetime": "24h"})
|
||||
def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
|
||||
self.register_user("kermit", "monkey")
|
||||
async def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
|
||||
await self.register_user("kermit", "monkey")
|
||||
|
||||
# log in as normal
|
||||
access_token = self.login("kermit", "monkey")
|
||||
access_token = await self.login("kermit", "monkey")
|
||||
|
||||
# we should now be able to make requests with the access token
|
||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
channel = await self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
# time passes
|
||||
self.reactor.advance(24 * 3600)
|
||||
|
||||
# ... and we should be soft-logouted
|
||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
channel = await self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
self.assertEqual(channel.code, 401, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
||||
self.assertEqual(channel.json_body["soft_logout"], True)
|
||||
|
||||
# Now try to hard logout this session
|
||||
channel = self.make_request(b"POST", "/logout", access_token=access_token)
|
||||
channel = await self.make_request(b"POST", "/logout", access_token=access_token)
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
|
||||
@override_config({"session_lifetime": "24h"})
|
||||
def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(
|
||||
async def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(
|
||||
self,
|
||||
) -> None:
|
||||
self.register_user("kermit", "monkey")
|
||||
await self.register_user("kermit", "monkey")
|
||||
|
||||
# log in as normal
|
||||
access_token = self.login("kermit", "monkey")
|
||||
access_token = await self.login("kermit", "monkey")
|
||||
|
||||
# we should now be able to make requests with the access token
|
||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
channel = await self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
# time passes
|
||||
self.reactor.advance(24 * 3600)
|
||||
|
||||
# ... and we should be soft-logouted
|
||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
channel = await self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
self.assertEqual(channel.code, 401, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
||||
self.assertEqual(channel.json_body["soft_logout"], True)
|
||||
|
||||
# Now try to hard log out all of the user's sessions
|
||||
channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
|
||||
channel = await self.make_request(b"POST", "/logout/all", access_token=access_token)
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
|
||||
def test_login_with_overly_long_device_id_fails(self) -> None:
|
||||
self.register_user("mickey", "cheese")
|
||||
async def test_login_with_overly_long_device_id_fails(self) -> None:
|
||||
await self.register_user("mickey", "cheese")
|
||||
|
||||
# create a device_id longer than 512 characters
|
||||
device_id = "yolo" * 512
|
||||
@@ -493,7 +493,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
}
|
||||
|
||||
# make a login request with the bad device_id
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/v3/login",
|
||||
body,
|
||||
@@ -514,8 +514,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_require_approval(self) -> None:
|
||||
channel = self.make_request(
|
||||
async def test_require_approval(self) -> None:
|
||||
channel = await self.make_request(
|
||||
"POST",
|
||||
"register",
|
||||
{
|
||||
@@ -535,25 +535,25 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"identifier": {"type": "m.id.user", "user": "kermit"},
|
||||
"password": "monkey",
|
||||
}
|
||||
channel = self.make_request("POST", LOGIN_URL, params)
|
||||
channel = await self.make_request("POST", LOGIN_URL, params)
|
||||
self.assertEqual(403, channel.code, channel.result)
|
||||
self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
|
||||
self.assertEqual(
|
||||
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
|
||||
)
|
||||
|
||||
def test_get_login_flows_with_login_via_existing_disabled(self) -> None:
|
||||
async def test_get_login_flows_with_login_via_existing_disabled(self) -> None:
|
||||
"""GET /login should return m.login.token without get_login_token"""
|
||||
channel = self.make_request("GET", "/_matrix/client/r0/login")
|
||||
channel = await self.make_request("GET", "/_matrix/client/r0/login")
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
flows = {flow["type"]: flow for flow in channel.json_body["flows"]}
|
||||
self.assertNotIn("m.login.token", flows)
|
||||
|
||||
@override_config({"login_via_existing_session": {"enabled": True}})
|
||||
def test_get_login_flows_with_login_via_existing_enabled(self) -> None:
|
||||
async def test_get_login_flows_with_login_via_existing_enabled(self) -> None:
|
||||
"""GET /login should return m.login.token with get_login_token true"""
|
||||
channel = self.make_request("GET", "/_matrix/client/r0/login")
|
||||
channel = await self.make_request("GET", "/_matrix/client/r0/login")
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
self.assertCountEqual(
|
||||
@@ -576,13 +576,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
]
|
||||
}
|
||||
)
|
||||
def test_spam_checker_allow(self) -> None:
|
||||
async def test_spam_checker_allow(self) -> None:
|
||||
"""Check that that adding a spam checker doesn't break login."""
|
||||
self.register_user("kermit", "monkey")
|
||||
await self.register_user("kermit", "monkey")
|
||||
|
||||
body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/r0/login",
|
||||
body,
|
||||
@@ -600,14 +600,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
]
|
||||
}
|
||||
)
|
||||
def test_spam_checker_deny(self) -> None:
|
||||
async def test_spam_checker_deny(self) -> None:
|
||||
"""Check that login"""
|
||||
|
||||
self.register_user("kermit", "monkey")
|
||||
await self.register_user("kermit", "monkey")
|
||||
|
||||
body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/r0/login",
|
||||
body,
|
||||
@@ -677,9 +677,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
d.update(build_synapse_client_resource_tree(self.hs))
|
||||
return d
|
||||
|
||||
def test_get_login_flows(self) -> None:
|
||||
async def test_get_login_flows(self) -> None:
|
||||
"""GET /login should return password and SSO flows"""
|
||||
channel = self.make_request("GET", "/_matrix/client/r0/login")
|
||||
channel = await self.make_request("GET", "/_matrix/client/r0/login")
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
expected_flow_types = [
|
||||
@@ -704,17 +704,17 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
],
|
||||
)
|
||||
|
||||
def test_multi_sso_redirect(self) -> None:
|
||||
async def test_multi_sso_redirect(self) -> None:
|
||||
"""/login/sso/redirect should redirect to an identity picker"""
|
||||
# first hit the redirect url, which should redirect to our idp picker
|
||||
channel = self._make_sso_redirect_request(None)
|
||||
channel = await self._make_sso_redirect_request(None)
|
||||
self.assertEqual(channel.code, 302, channel.result)
|
||||
location_headers = channel.headers.getRawHeaders("Location")
|
||||
assert location_headers
|
||||
uri = location_headers[0]
|
||||
|
||||
# hitting that picker should give us some HTML
|
||||
channel = self.make_request("GET", uri)
|
||||
channel = await self.make_request("GET", uri)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
# parse the form to check it has fields assumed elsewhere in this class
|
||||
@@ -734,10 +734,10 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"])
|
||||
|
||||
def test_multi_sso_redirect_to_cas(self) -> None:
|
||||
async def test_multi_sso_redirect_to_cas(self) -> None:
|
||||
"""If CAS is chosen, should redirect to the CAS server"""
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
"/_synapse/client/pick_idp?redirectUrl="
|
||||
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
|
||||
@@ -758,7 +758,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
# follow the redirect
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
# We have to make this relative to be compatible with `make_request(...)`
|
||||
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
|
||||
@@ -787,9 +787,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
service_uri_params = urllib.parse.parse_qs(service_uri_query)
|
||||
self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)
|
||||
|
||||
def test_multi_sso_redirect_to_saml(self) -> None:
|
||||
async def test_multi_sso_redirect_to_saml(self) -> None:
|
||||
"""If SAML is chosen, should redirect to the SAML server"""
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
"/_synapse/client/pick_idp?redirectUrl="
|
||||
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
|
||||
@@ -809,7 +809,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
# follow the redirect
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
# We have to make this relative to be compatible with `make_request(...)`
|
||||
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
|
||||
@@ -835,14 +835,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
relay_state_param = saml_uri_params["RelayState"][0]
|
||||
self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)
|
||||
|
||||
def test_login_via_oidc(self) -> None:
|
||||
async def test_login_via_oidc(self) -> None:
|
||||
"""If OIDC is chosen, should redirect to the OIDC auth endpoint"""
|
||||
|
||||
fake_oidc_server = self.helper.fake_oidc_server()
|
||||
|
||||
with fake_oidc_server.patch_homeserver(hs=self.hs):
|
||||
# pick the default OIDC provider
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
f"/_synapse/client/pick_idp?redirectUrl={urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)}&idp=oidc",
|
||||
)
|
||||
@@ -861,7 +861,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
with fake_oidc_server.patch_homeserver(hs=self.hs):
|
||||
# follow the redirect
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
# We have to make this relative to be compatible with `make_request(...)`
|
||||
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
|
||||
@@ -897,7 +897,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
TEST_CLIENT_REDIRECT_URL,
|
||||
)
|
||||
|
||||
channel, _ = self.helper.complete_oidc_auth(
|
||||
channel, _ = await self.helper.complete_oidc_auth(
|
||||
fake_oidc_server, oidc_uri, cookies, {"sub": "user1"}
|
||||
)
|
||||
|
||||
@@ -925,7 +925,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
# finally, submit the matrix login token to the login API, which gives us our
|
||||
# matrix access token, mxid, and device id.
|
||||
login_token = params[2][1]
|
||||
chan = self.make_request(
|
||||
chan = await self.make_request(
|
||||
"POST",
|
||||
"/login",
|
||||
content={"type": "m.login.token", "token": login_token},
|
||||
@@ -933,15 +933,15 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(chan.code, 200, chan.result)
|
||||
self.assertEqual(chan.json_body["user_id"], "@user1:test")
|
||||
|
||||
def test_multi_sso_redirect_unknown_idp(self) -> None:
|
||||
async def test_multi_sso_redirect_unknown_idp(self) -> None:
|
||||
"""An unknown IdP should cause a 400 bad request error"""
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
|
||||
def test_multi_sso_redirect_unknown_idp_as_url(self) -> None:
|
||||
async def test_multi_sso_redirect_unknown_idp_as_url(self) -> None:
|
||||
"""
|
||||
An unknown IdP that looks like a URL should cause a 400 bad request error (to
|
||||
avoid open redirects).
|
||||
@@ -954,24 +954,24 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
in the URL building tests to cover this case but is only a unit test vs
|
||||
something at the REST layer here that covers things end-to-end.
|
||||
"""
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
"/_synapse/client/pick_idp?redirectUrl=something&idp=https://element.io/",
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
|
||||
def test_client_idp_redirect_to_unknown(self) -> None:
|
||||
async def test_client_idp_redirect_to_unknown(self) -> None:
|
||||
"""If the client tries to pick an unknown IdP, return a 404"""
|
||||
channel = self._make_sso_redirect_request("xxx")
|
||||
channel = await self._make_sso_redirect_request("xxx")
|
||||
self.assertEqual(channel.code, 404, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
||||
|
||||
def test_client_idp_redirect_to_oidc(self) -> None:
|
||||
async def test_client_idp_redirect_to_oidc(self) -> None:
|
||||
"""If the client pick a known IdP, redirect to it"""
|
||||
fake_oidc_server = self.helper.fake_oidc_server()
|
||||
|
||||
with fake_oidc_server.patch_homeserver(hs=self.hs):
|
||||
channel = self._make_sso_redirect_request("oidc")
|
||||
channel = await self._make_sso_redirect_request("oidc")
|
||||
self.assertEqual(channel.code, 302, channel.result)
|
||||
location_headers = channel.headers.getRawHeaders("Location")
|
||||
assert location_headers
|
||||
@@ -981,7 +981,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
# it should redirect us to the auth page of the OIDC server
|
||||
self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint)
|
||||
|
||||
def _make_sso_redirect_request(self, idp_prov: str | None = None) -> FakeChannel:
|
||||
async def _make_sso_redirect_request(self, idp_prov: str | None = None) -> FakeChannel:
|
||||
"""Send a request to /_matrix/client/r0/login/sso/redirect
|
||||
|
||||
... possibly specifying an IDP provider
|
||||
@@ -991,7 +991,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
endpoint += "/" + idp_prov
|
||||
endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
|
||||
|
||||
return self.make_request(
|
||||
return await self.make_request(
|
||||
"GET",
|
||||
endpoint,
|
||||
custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)],
|
||||
@@ -1011,7 +1011,7 @@ class CASTestCase(unittest.HomeserverTestCase):
|
||||
login.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.base_url = "https://matrix.goodserver.com/"
|
||||
self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
|
||||
|
||||
@@ -1054,7 +1054,7 @@ class CASTestCase(unittest.HomeserverTestCase):
|
||||
mocked_http_client = Mock(spec=["get_raw"])
|
||||
mocked_http_client.get_raw.side_effect = get_raw
|
||||
|
||||
self.hs = self.setup_test_homeserver(
|
||||
self.hs = await self.setup_test_homeserver(
|
||||
config=config,
|
||||
proxied_http_client=mocked_http_client,
|
||||
)
|
||||
@@ -1064,7 +1064,7 @@ class CASTestCase(unittest.HomeserverTestCase):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.deactivate_account_handler = hs.get_deactivate_account_handler()
|
||||
|
||||
def test_cas_redirect_confirm(self) -> None:
|
||||
async def test_cas_redirect_confirm(self) -> None:
|
||||
"""Tests that the SSO login flow serves a confirmation page before redirecting a
|
||||
user to the redirect URL.
|
||||
"""
|
||||
@@ -1079,7 +1079,7 @@ class CASTestCase(unittest.HomeserverTestCase):
|
||||
cas_ticket_url = urllib.parse.urlunparse(url_parts)
|
||||
|
||||
# Get Synapse to call the fake CAS and serve the template.
|
||||
channel = self.make_request("GET", cas_ticket_url)
|
||||
channel = await self.make_request("GET", cas_ticket_url)
|
||||
|
||||
# Test that the response is HTML.
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
@@ -1105,15 +1105,15 @@ class CASTestCase(unittest.HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_cas_redirect_whitelisted(self) -> None:
|
||||
async def test_cas_redirect_whitelisted(self) -> None:
|
||||
"""Tests that the SSO login flow serves a redirect to a whitelisted url"""
|
||||
self._test_redirect("https://legit-site.com/")
|
||||
await self._test_redirect("https://legit-site.com/")
|
||||
|
||||
@override_config({"public_baseurl": "https://example.com"})
|
||||
def test_cas_redirect_login_fallback(self) -> None:
|
||||
self._test_redirect("https://example.com/_matrix/static/client/login")
|
||||
async def test_cas_redirect_login_fallback(self) -> None:
|
||||
await self._test_redirect("https://example.com/_matrix/static/client/login")
|
||||
|
||||
def _test_redirect(self, redirect_url: str) -> None:
|
||||
async def _test_redirect(self, redirect_url: str) -> None:
|
||||
"""Tests that the SSO login flow serves a redirect for the given redirect URL."""
|
||||
cas_ticket_url = (
|
||||
"/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
|
||||
@@ -1121,7 +1121,7 @@ class CASTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
# Get Synapse to call the fake CAS and serve the template.
|
||||
channel = self.make_request("GET", cas_ticket_url)
|
||||
channel = await self.make_request("GET", cas_ticket_url)
|
||||
|
||||
self.assertEqual(channel.code, 302)
|
||||
location_headers = channel.headers.getRawHeaders("Location")
|
||||
@@ -1129,15 +1129,15 @@ class CASTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
|
||||
|
||||
@override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
|
||||
def test_deactivated_user(self) -> None:
|
||||
async def test_deactivated_user(self) -> None:
|
||||
"""Logging in as a deactivated account should error."""
|
||||
redirect_url = "https://legit-site.com/"
|
||||
|
||||
# First login (to create the user).
|
||||
self._test_redirect(redirect_url)
|
||||
await self._test_redirect(redirect_url)
|
||||
|
||||
# Deactivate the account.
|
||||
self.get_success(
|
||||
await self.get_success(
|
||||
self.deactivate_account_handler.deactivate_account(
|
||||
self.user_id, False, create_requester(self.user_id)
|
||||
)
|
||||
@@ -1150,7 +1150,7 @@ class CASTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
# Get Synapse to call the fake CAS and serve the template.
|
||||
channel = self.make_request("GET", cas_ticket_url)
|
||||
channel = await self.make_request("GET", cas_ticket_url)
|
||||
|
||||
# Because the user is deactivated they are served an error template.
|
||||
self.assertEqual(channel.code, 403)
|
||||
@@ -1187,24 +1187,24 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||
result: bytes = jwt.encode(header, payload, secret)
|
||||
return result.decode("ascii")
|
||||
|
||||
def jwt_login(self, *args: Any) -> FakeChannel:
|
||||
async def jwt_login(self, *args: Any) -> FakeChannel:
|
||||
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
|
||||
channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
channel = await self.make_request(b"POST", LOGIN_URL, params)
|
||||
return channel
|
||||
|
||||
def test_login_jwt_valid_registered(self) -> None:
|
||||
self.register_user("kermit", "monkey")
|
||||
channel = self.jwt_login({"sub": "kermit"})
|
||||
async def test_login_jwt_valid_registered(self) -> None:
|
||||
await self.register_user("kermit", "monkey")
|
||||
channel = await self.jwt_login({"sub": "kermit"})
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
||||
|
||||
def test_login_jwt_valid_unregistered(self) -> None:
|
||||
channel = self.jwt_login({"sub": "frog"})
|
||||
async def test_login_jwt_valid_unregistered(self) -> None:
|
||||
channel = await self.jwt_login({"sub": "frog"})
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["user_id"], "@frog:test")
|
||||
|
||||
def test_login_jwt_invalid_signature(self) -> None:
|
||||
channel = self.jwt_login({"sub": "frog"}, "notsecret")
|
||||
async def test_login_jwt_invalid_signature(self) -> None:
|
||||
channel = await self.jwt_login({"sub": "frog"}, "notsecret")
|
||||
self.assertEqual(channel.code, 403, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||
self.assertEqual(
|
||||
@@ -1212,8 +1212,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||
"JWT validation failed: Signature verification failed",
|
||||
)
|
||||
|
||||
def test_login_jwt_expired(self) -> None:
|
||||
channel = self.jwt_login({"sub": "frog", "exp": 864000})
|
||||
async def test_login_jwt_expired(self) -> None:
|
||||
channel = await self.jwt_login({"sub": "frog", "exp": 864000})
|
||||
self.assertEqual(channel.code, 403, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||
self.assertEqual(
|
||||
@@ -1221,9 +1221,9 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||
"JWT validation failed: expired_token: The token is expired",
|
||||
)
|
||||
|
||||
def test_login_jwt_not_before(self) -> None:
|
||||
async def test_login_jwt_not_before(self) -> None:
|
||||
now = int(time.time())
|
||||
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
|
||||
channel = await self.jwt_login({"sub": "frog", "nbf": now + 3600})
|
||||
self.assertEqual(channel.code, 403, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||
self.assertEqual(
|
||||
@@ -1231,22 +1231,22 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||
"JWT validation failed: invalid_token: The token is not valid yet",
|
||||
)
|
||||
|
||||
def test_login_no_sub(self) -> None:
|
||||
channel = self.jwt_login({"username": "root"})
|
||||
async def test_login_no_sub(self) -> None:
|
||||
channel = await self.jwt_login({"username": "root"})
|
||||
self.assertEqual(channel.code, 403, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||
self.assertEqual(channel.json_body["error"], "Invalid JWT")
|
||||
|
||||
@override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}})
|
||||
def test_login_iss(self) -> None:
|
||||
async def test_login_iss(self) -> None:
|
||||
"""Test validating the issuer claim."""
|
||||
# A valid issuer.
|
||||
channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
|
||||
channel = await self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
||||
|
||||
# An invalid issuer.
|
||||
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
|
||||
channel = await self.jwt_login({"sub": "kermit", "iss": "invalid"})
|
||||
self.assertEqual(channel.code, 403, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||
self.assertRegex(
|
||||
@@ -1255,7 +1255,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
# Not providing an issuer.
|
||||
channel = self.jwt_login({"sub": "kermit"})
|
||||
channel = await self.jwt_login({"sub": "kermit"})
|
||||
self.assertEqual(channel.code, 403, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||
self.assertRegex(
|
||||
@@ -1263,22 +1263,22 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||
r"^JWT validation failed: missing_claim: Missing [\"']iss[\"'] claim$",
|
||||
)
|
||||
|
||||
def test_login_iss_no_config(self) -> None:
|
||||
async def test_login_iss_no_config(self) -> None:
|
||||
"""Test providing an issuer claim without requiring it in the configuration."""
|
||||
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
|
||||
channel = await self.jwt_login({"sub": "kermit", "iss": "invalid"})
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
||||
|
||||
@override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
|
||||
def test_login_aud(self) -> None:
|
||||
async def test_login_aud(self) -> None:
|
||||
"""Test validating the audience claim."""
|
||||
# A valid audience.
|
||||
channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
|
||||
channel = await self.jwt_login({"sub": "kermit", "aud": "test-audience"})
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
||||
|
||||
# An invalid audience.
|
||||
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
|
||||
channel = await self.jwt_login({"sub": "kermit", "aud": "invalid"})
|
||||
self.assertEqual(channel.code, 403, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||
self.assertRegex(
|
||||
@@ -1287,7 +1287,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
# Not providing an audience.
|
||||
channel = self.jwt_login({"sub": "kermit"})
|
||||
channel = await self.jwt_login({"sub": "kermit"})
|
||||
self.assertEqual(channel.code, 403, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||
self.assertRegex(
|
||||
@@ -1295,9 +1295,9 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||
r"^JWT validation failed: missing_claim: Missing [\"']aud[\"'] claim$",
|
||||
)
|
||||
|
||||
def test_login_aud_no_config(self) -> None:
|
||||
async def test_login_aud_no_config(self) -> None:
|
||||
"""Test providing an audience without requiring it in the configuration."""
|
||||
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
|
||||
channel = await self.jwt_login({"sub": "kermit", "aud": "invalid"})
|
||||
self.assertEqual(channel.code, 403, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||
self.assertRegex(
|
||||
@@ -1305,36 +1305,36 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||
r"^JWT validation failed: invalid_claim: Invalid claim [\"']aud[\"']$",
|
||||
)
|
||||
|
||||
def test_login_default_sub(self) -> None:
|
||||
async def test_login_default_sub(self) -> None:
|
||||
"""Test reading user ID from the default subject claim."""
|
||||
channel = self.jwt_login({"sub": "kermit"})
|
||||
channel = await self.jwt_login({"sub": "kermit"})
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
||||
|
||||
@override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
|
||||
def test_login_custom_sub(self) -> None:
|
||||
async def test_login_custom_sub(self) -> None:
|
||||
"""Test reading user ID from a custom subject claim."""
|
||||
channel = self.jwt_login({"username": "frog"})
|
||||
channel = await self.jwt_login({"username": "frog"})
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["user_id"], "@frog:test")
|
||||
|
||||
@override_config(
|
||||
{"jwt_config": {**base_config, "display_name_claim": "display_name"}}
|
||||
)
|
||||
def test_login_custom_display_name(self) -> None:
|
||||
async def test_login_custom_display_name(self) -> None:
|
||||
"""Test setting a custom display name."""
|
||||
localpart = "pinkie"
|
||||
user_id = f"@{localpart}:test"
|
||||
display_name = "Pinkie Pie"
|
||||
|
||||
# Perform the login, specifying a custom display name.
|
||||
channel = self.jwt_login({"sub": localpart, "display_name": display_name})
|
||||
channel = await self.jwt_login({"sub": localpart, "display_name": display_name})
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["user_id"], user_id)
|
||||
|
||||
# Fetch the user's display name and check that it was set correctly.
|
||||
access_token = channel.json_body["access_token"]
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v3/profile/{user_id}/displayname",
|
||||
access_token=access_token,
|
||||
@@ -1342,23 +1342,23 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["displayname"], display_name)
|
||||
|
||||
def test_login_no_token(self) -> None:
|
||||
async def test_login_no_token(self) -> None:
|
||||
params = {"type": "org.matrix.login.jwt"}
|
||||
channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
channel = await self.make_request(b"POST", LOGIN_URL, params)
|
||||
self.assertEqual(channel.code, 403, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
|
||||
|
||||
def test_deactivated_user(self) -> None:
|
||||
async def test_deactivated_user(self) -> None:
|
||||
"""Logging in as a deactivated account should error."""
|
||||
user_id = self.register_user("kermit", "monkey")
|
||||
self.get_success(
|
||||
user_id = await self.register_user("kermit", "monkey")
|
||||
await self.get_success(
|
||||
self.hs.get_deactivate_account_handler().deactivate_account(
|
||||
user_id, erase_data=False, requester=create_requester(user_id)
|
||||
)
|
||||
)
|
||||
|
||||
channel = self.jwt_login({"sub": "kermit"})
|
||||
channel = await self.jwt_login({"sub": "kermit"})
|
||||
self.assertEqual(channel.code, 403, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_USER_DEACTIVATED")
|
||||
self.assertEqual(
|
||||
@@ -1436,18 +1436,18 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
|
||||
result: bytes = jwt.encode(header, payload, secret)
|
||||
return result.decode("ascii")
|
||||
|
||||
def jwt_login(self, *args: Any) -> FakeChannel:
|
||||
async def jwt_login(self, *args: Any) -> FakeChannel:
|
||||
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
|
||||
channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
channel = await self.make_request(b"POST", LOGIN_URL, params)
|
||||
return channel
|
||||
|
||||
def test_login_jwt_valid(self) -> None:
|
||||
channel = self.jwt_login({"sub": "kermit"})
|
||||
async def test_login_jwt_valid(self) -> None:
|
||||
channel = await self.jwt_login({"sub": "kermit"})
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
||||
|
||||
def test_login_jwt_invalid_signature(self) -> None:
|
||||
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
|
||||
async def test_login_jwt_invalid_signature(self) -> None:
|
||||
channel = await self.jwt_login({"sub": "frog"}, self.bad_privatekey)
|
||||
self.assertEqual(channel.code, 403, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||
self.assertEqual(
|
||||
@@ -1465,8 +1465,8 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
register.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.hs = self.setup_test_homeserver()
|
||||
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.hs = await self.setup_test_homeserver()
|
||||
|
||||
self.service = ApplicationService(
|
||||
id="unique_identifier",
|
||||
@@ -1511,23 +1511,23 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
self.hs.get_datastores().main.services_cache.append(self.msc4190_service)
|
||||
return self.hs
|
||||
|
||||
def test_login_appservice_user(self) -> None:
|
||||
async def test_login_appservice_user(self) -> None:
|
||||
"""Test that an appservice user can use /login"""
|
||||
self.register_appservice_user(AS_USER, self.service.token)
|
||||
await self.register_appservice_user(AS_USER, self.service.token)
|
||||
|
||||
params = {
|
||||
"type": login.LoginRestServlet.APPSERVICE_TYPE,
|
||||
"identifier": {"type": "m.id.user", "user": AS_USER},
|
||||
}
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
b"POST", LOGIN_URL, params, access_token=self.service.token
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
|
||||
def test_login_appservice_msc4190_fail(self) -> None:
|
||||
async def test_login_appservice_msc4190_fail(self) -> None:
|
||||
"""Test that an appservice user can use /login"""
|
||||
self.register_appservice_user(
|
||||
await self.register_appservice_user(
|
||||
"as3_user_alice", self.msc4190_service.token, inhibit_login=True
|
||||
)
|
||||
|
||||
@@ -1535,7 +1535,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"type": login.LoginRestServlet.APPSERVICE_TYPE,
|
||||
"identifier": {"type": "m.id.user", "user": "as3_user_alice"},
|
||||
}
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
b"POST", LOGIN_URL, params, access_token=self.msc4190_service.token
|
||||
)
|
||||
|
||||
@@ -1546,9 +1546,9 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
channel.json_body,
|
||||
)
|
||||
|
||||
def test_login_appservice_user_bot(self) -> None:
|
||||
async def test_login_appservice_user_bot(self) -> None:
|
||||
"""Test that the appservice bot can use /login"""
|
||||
self.register_appservice_user(AS_USER, self.service.token)
|
||||
await self.register_appservice_user(AS_USER, self.service.token)
|
||||
|
||||
params = {
|
||||
"type": login.LoginRestServlet.APPSERVICE_TYPE,
|
||||
@@ -1557,51 +1557,51 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"user": self.service.sender.to_string(),
|
||||
},
|
||||
}
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
b"POST", LOGIN_URL, params, access_token=self.service.token
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
|
||||
def test_login_appservice_wrong_user(self) -> None:
|
||||
async def test_login_appservice_wrong_user(self) -> None:
|
||||
"""Test that non-as users cannot login with the as token"""
|
||||
self.register_appservice_user(AS_USER, self.service.token)
|
||||
await self.register_appservice_user(AS_USER, self.service.token)
|
||||
|
||||
params = {
|
||||
"type": login.LoginRestServlet.APPSERVICE_TYPE,
|
||||
"identifier": {"type": "m.id.user", "user": "fibble_wibble"},
|
||||
}
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
b"POST", LOGIN_URL, params, access_token=self.service.token
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 403, msg=channel.result)
|
||||
|
||||
def test_login_appservice_wrong_as(self) -> None:
|
||||
async def test_login_appservice_wrong_as(self) -> None:
|
||||
"""Test that as users cannot login with wrong as token"""
|
||||
self.register_appservice_user(AS_USER, self.service.token)
|
||||
await self.register_appservice_user(AS_USER, self.service.token)
|
||||
|
||||
params = {
|
||||
"type": login.LoginRestServlet.APPSERVICE_TYPE,
|
||||
"identifier": {"type": "m.id.user", "user": AS_USER},
|
||||
}
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
b"POST", LOGIN_URL, params, access_token=self.another_service.token
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 403, msg=channel.result)
|
||||
|
||||
def test_login_appservice_no_token(self) -> None:
|
||||
async def test_login_appservice_no_token(self) -> None:
|
||||
"""Test that users must provide a token when using the appservice
|
||||
login method
|
||||
"""
|
||||
self.register_appservice_user(AS_USER, self.service.token)
|
||||
await self.register_appservice_user(AS_USER, self.service.token)
|
||||
|
||||
params = {
|
||||
"type": login.LoginRestServlet.APPSERVICE_TYPE,
|
||||
"identifier": {"type": "m.id.user", "user": AS_USER},
|
||||
}
|
||||
channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
channel = await self.make_request(b"POST", LOGIN_URL, params)
|
||||
|
||||
self.assertEqual(channel.code, 401, msg=channel.result)
|
||||
|
||||
@@ -1616,10 +1616,10 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
account.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.http_client = Mock(spec=["get_file"])
|
||||
self.http_client.get_file.side_effect = mock_get_file
|
||||
hs = self.setup_test_homeserver(
|
||||
hs = await self.setup_test_homeserver(
|
||||
proxied_blocklisted_http_client=self.http_client
|
||||
)
|
||||
return hs
|
||||
@@ -1648,7 +1648,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
d.update(build_synapse_client_resource_tree(self.hs))
|
||||
return d
|
||||
|
||||
def proceed_to_username_picker_page(
|
||||
async def proceed_to_username_picker_page(
|
||||
self,
|
||||
fake_oidc_server: FakeOidcServer,
|
||||
displayname: str,
|
||||
@@ -1656,7 +1656,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
picture: str,
|
||||
) -> tuple[str, str]:
|
||||
# do the start of the login flow
|
||||
channel, _ = self.helper.auth_via_oidc(
|
||||
channel, _ = await self.helper.auth_via_oidc(
|
||||
fake_oidc_server,
|
||||
{
|
||||
"sub": "tester",
|
||||
@@ -1701,7 +1701,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
|
||||
return picker_url, session_id
|
||||
|
||||
def test_username_picker_use_displayname_avatar_and_email(self) -> None:
|
||||
async def test_username_picker_use_displayname_avatar_and_email(self) -> None:
|
||||
"""Test the happy path of a username picker flow with using displayname, avatar and email."""
|
||||
|
||||
fake_oidc_server = self.helper.fake_oidc_server()
|
||||
@@ -1711,7 +1711,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
email = "bobby@test.com"
|
||||
picture = "mxc://test/avatar_url"
|
||||
|
||||
picker_url, session_id = self.proceed_to_username_picker_page(
|
||||
picker_url, session_id = await self.proceed_to_username_picker_page(
|
||||
fake_oidc_server, displayname, email, picture
|
||||
)
|
||||
|
||||
@@ -1726,7 +1726,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
b"use_email": email,
|
||||
}
|
||||
).encode("utf8")
|
||||
chan = self.make_request(
|
||||
chan = await self.make_request(
|
||||
"POST",
|
||||
path=picker_url,
|
||||
content=content,
|
||||
@@ -1740,7 +1740,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
assert location_headers
|
||||
|
||||
# send a request to the completion page, which should 302 to the client redirectUrl
|
||||
chan = self.make_request(
|
||||
chan = await self.make_request(
|
||||
"GET",
|
||||
path=location_headers[0],
|
||||
custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
|
||||
@@ -1765,7 +1765,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
|
||||
# finally, submit the matrix login token to the login API, which gives us our
|
||||
# matrix access token, mxid, and device id.
|
||||
chan = self.make_request(
|
||||
chan = await self.make_request(
|
||||
"POST",
|
||||
"/login",
|
||||
content={"type": "m.login.token", "token": login_token},
|
||||
@@ -1774,7 +1774,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
self.assertEqual(chan.json_body["user_id"], mxid)
|
||||
|
||||
# ensure the displayname and avatar from the OIDC response have been configured for the user.
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET", "/profile/" + mxid, access_token=chan.json_body["access_token"]
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
@@ -1782,13 +1782,13 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
self.assertEqual(displayname, channel.json_body["displayname"])
|
||||
|
||||
# ensure the email from the OIDC response has been configured for the user.
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET", "/account/3pid", access_token=chan.json_body["access_token"]
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual(email, channel.json_body["threepids"][0]["address"])
|
||||
|
||||
def test_username_picker_dont_use_displayname_avatar_or_email(self) -> None:
|
||||
async def test_username_picker_dont_use_displayname_avatar_or_email(self) -> None:
|
||||
"""Test the happy path of a username picker flow without using displayname, avatar or email."""
|
||||
|
||||
fake_oidc_server = self.helper.fake_oidc_server()
|
||||
@@ -1799,7 +1799,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
picture = "mxc://test/avatar_url"
|
||||
username = "bobby"
|
||||
|
||||
picker_url, session_id = self.proceed_to_username_picker_page(
|
||||
picker_url, session_id = await self.proceed_to_username_picker_page(
|
||||
fake_oidc_server, displayname, email, picture
|
||||
)
|
||||
|
||||
@@ -1813,7 +1813,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
b"use_avatar": b"false",
|
||||
}
|
||||
).encode("utf8")
|
||||
chan = self.make_request(
|
||||
chan = await self.make_request(
|
||||
"POST",
|
||||
path=picker_url,
|
||||
content=content,
|
||||
@@ -1827,7 +1827,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
assert location_headers
|
||||
|
||||
# send a request to the completion page, which should 302 to the client redirectUrl
|
||||
chan = self.make_request(
|
||||
chan = await self.make_request(
|
||||
"GET",
|
||||
path=location_headers[0],
|
||||
custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
|
||||
@@ -1852,7 +1852,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
|
||||
# finally, submit the matrix login token to the login API, which gives us our
|
||||
# matrix access token, mxid, and device id.
|
||||
chan = self.make_request(
|
||||
chan = await self.make_request(
|
||||
"POST",
|
||||
"/login",
|
||||
content={"type": "m.login.token", "token": login_token},
|
||||
@@ -1861,7 +1861,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
self.assertEqual(chan.json_body["user_id"], mxid)
|
||||
|
||||
# ensure the displayname and avatar from the OIDC response have not been configured for the user.
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET", "/profile/" + mxid, access_token=chan.json_body["access_token"]
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
@@ -1869,7 +1869,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
self.assertEqual(username, channel.json_body["displayname"])
|
||||
|
||||
# ensure the email from the OIDC response has not been configured for the user.
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET", "/account/3pid", access_token=chan.json_body["access_token"]
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
@@ -81,7 +81,7 @@ class RestHelper:
|
||||
auth_user_id: str | None
|
||||
|
||||
@overload
|
||||
def create_room_as(
|
||||
async def create_room_as(
|
||||
self,
|
||||
room_creator: str | None = ...,
|
||||
is_public: bool | None = ...,
|
||||
@@ -93,7 +93,7 @@ class RestHelper:
|
||||
) -> str: ...
|
||||
|
||||
@overload
|
||||
def create_room_as(
|
||||
async def create_room_as(
|
||||
self,
|
||||
room_creator: str | None = ...,
|
||||
is_public: bool | None = ...,
|
||||
@@ -104,7 +104,7 @@ class RestHelper:
|
||||
custom_headers: Iterable[tuple[AnyStr, AnyStr]] | None = ...,
|
||||
) -> str | None: ...
|
||||
|
||||
def create_room_as(
|
||||
async def create_room_as(
|
||||
self,
|
||||
room_creator: str | None = None,
|
||||
is_public: bool | None = True,
|
||||
@@ -149,7 +149,7 @@ class RestHelper:
|
||||
if tok:
|
||||
path = path + "?access_token=%s" % tok
|
||||
|
||||
channel = make_request(
|
||||
channel = await make_request(
|
||||
self.reactor,
|
||||
self.site,
|
||||
"POST",
|
||||
@@ -166,7 +166,7 @@ class RestHelper:
|
||||
else:
|
||||
return None
|
||||
|
||||
def invite(
|
||||
async def invite(
|
||||
self,
|
||||
room: str,
|
||||
src: str | None = None,
|
||||
@@ -175,7 +175,7 @@ class RestHelper:
|
||||
tok: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
) -> JsonDict:
|
||||
return self.change_membership(
|
||||
return await self.change_membership(
|
||||
room=room,
|
||||
src=src,
|
||||
targ=targ,
|
||||
@@ -185,7 +185,7 @@ class RestHelper:
|
||||
extra_data=extra_data,
|
||||
)
|
||||
|
||||
def join(
|
||||
async def join(
|
||||
self,
|
||||
room: str,
|
||||
user: str,
|
||||
@@ -195,7 +195,7 @@ class RestHelper:
|
||||
expect_errcode: Codes | None = None,
|
||||
expect_additional_fields: dict | None = None,
|
||||
) -> JsonDict:
|
||||
return self.change_membership(
|
||||
return await self.change_membership(
|
||||
room=room,
|
||||
src=user,
|
||||
targ=user,
|
||||
@@ -207,7 +207,7 @@ class RestHelper:
|
||||
expect_additional_fields=expect_additional_fields,
|
||||
)
|
||||
|
||||
def knock(
|
||||
async def knock(
|
||||
self,
|
||||
room: str | None = None,
|
||||
user: str | None = None,
|
||||
@@ -225,7 +225,7 @@ class RestHelper:
|
||||
if reason:
|
||||
data["reason"] = reason
|
||||
|
||||
channel = make_request(
|
||||
channel = await make_request(
|
||||
self.reactor,
|
||||
self.site,
|
||||
"POST",
|
||||
@@ -241,14 +241,14 @@ class RestHelper:
|
||||
|
||||
self.auth_user_id = temp_id
|
||||
|
||||
def leave(
|
||||
async def leave(
|
||||
self,
|
||||
room: str,
|
||||
user: str | None = None,
|
||||
expect_code: int = HTTPStatus.OK,
|
||||
tok: str | None = None,
|
||||
) -> JsonDict:
|
||||
return self.change_membership(
|
||||
return await self.change_membership(
|
||||
room=room,
|
||||
src=user,
|
||||
targ=user,
|
||||
@@ -257,7 +257,7 @@ class RestHelper:
|
||||
expect_code=expect_code,
|
||||
)
|
||||
|
||||
def ban(
|
||||
async def ban(
|
||||
self,
|
||||
room: str,
|
||||
src: str,
|
||||
@@ -266,7 +266,7 @@ class RestHelper:
|
||||
tok: str | None = None,
|
||||
) -> JsonDict:
|
||||
"""A convenience helper: `change_membership` with `membership` preset to "ban"."""
|
||||
return self.change_membership(
|
||||
return await self.change_membership(
|
||||
room=room,
|
||||
src=src,
|
||||
targ=targ,
|
||||
@@ -275,7 +275,7 @@ class RestHelper:
|
||||
expect_code=expect_code,
|
||||
)
|
||||
|
||||
def change_membership(
|
||||
async def change_membership(
|
||||
self,
|
||||
room: str,
|
||||
src: str | None,
|
||||
@@ -325,7 +325,7 @@ class RestHelper:
|
||||
data = {"membership": membership}
|
||||
data.update(extra_data or {})
|
||||
|
||||
channel = make_request(
|
||||
channel = await make_request(
|
||||
self.reactor,
|
||||
self.site,
|
||||
"PUT",
|
||||
@@ -372,7 +372,7 @@ class RestHelper:
|
||||
self.auth_user_id = temp_id
|
||||
return channel.json_body
|
||||
|
||||
def send(
|
||||
async def send(
|
||||
self,
|
||||
room_id: str,
|
||||
body: str | None = None,
|
||||
@@ -387,7 +387,7 @@ class RestHelper:
|
||||
|
||||
content = {"msgtype": "m.text", "body": body}
|
||||
|
||||
return self.send_event(
|
||||
return await self.send_event(
|
||||
room_id,
|
||||
type,
|
||||
content,
|
||||
@@ -397,7 +397,7 @@ class RestHelper:
|
||||
custom_headers=custom_headers,
|
||||
)
|
||||
|
||||
def send_messages(
|
||||
async def send_messages(
|
||||
self,
|
||||
room_id: str,
|
||||
num_events: int,
|
||||
@@ -413,7 +413,7 @@ class RestHelper:
|
||||
event_ids = []
|
||||
|
||||
for event_index in range(num_events):
|
||||
response = self.send_event(
|
||||
response = await self.send_event(
|
||||
room_id,
|
||||
EventTypes.Message,
|
||||
content_fn(event_index),
|
||||
@@ -423,7 +423,7 @@ class RestHelper:
|
||||
|
||||
return event_ids
|
||||
|
||||
def send_event(
|
||||
async def send_event(
|
||||
self,
|
||||
room_id: str,
|
||||
type: str,
|
||||
@@ -440,7 +440,7 @@ class RestHelper:
|
||||
if tok:
|
||||
path = path + "?access_token=%s" % tok
|
||||
|
||||
channel = make_request(
|
||||
channel = await make_request(
|
||||
self.reactor,
|
||||
self.site,
|
||||
"PUT",
|
||||
@@ -457,7 +457,7 @@ class RestHelper:
|
||||
|
||||
return channel.json_body
|
||||
|
||||
def send_sticky_event(
|
||||
async def send_sticky_event(
|
||||
self,
|
||||
room_id: str,
|
||||
type: str,
|
||||
@@ -480,7 +480,7 @@ class RestHelper:
|
||||
if tok:
|
||||
path = path + f"&access_token={tok}"
|
||||
|
||||
channel = make_request(
|
||||
channel = await make_request(
|
||||
self.reactor,
|
||||
self.site,
|
||||
"PUT",
|
||||
@@ -495,7 +495,7 @@ class RestHelper:
|
||||
|
||||
return channel.json_body
|
||||
|
||||
def get_event(
|
||||
async def get_event(
|
||||
self,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
@@ -517,7 +517,7 @@ class RestHelper:
|
||||
if tok:
|
||||
path = path + f"?access_token={tok}"
|
||||
|
||||
channel = make_request(
|
||||
channel = await make_request(
|
||||
self.reactor,
|
||||
self.site,
|
||||
"GET",
|
||||
@@ -532,7 +532,7 @@ class RestHelper:
|
||||
|
||||
return channel.json_body
|
||||
|
||||
def _read_write_state(
|
||||
async def _read_write_state(
|
||||
self,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
@@ -573,7 +573,7 @@ class RestHelper:
|
||||
if body is not None:
|
||||
content = json.dumps(body).encode("utf8")
|
||||
|
||||
channel = make_request(self.reactor, self.site, method, path, content)
|
||||
channel = await make_request(self.reactor, self.site, method, path, content)
|
||||
|
||||
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
|
||||
expect_code,
|
||||
@@ -583,7 +583,7 @@ class RestHelper:
|
||||
|
||||
return channel.json_body
|
||||
|
||||
def get_state(
|
||||
async def get_state(
|
||||
self,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
@@ -606,11 +606,11 @@ class RestHelper:
|
||||
Raises:
|
||||
AssertionError: if expect_code doesn't match the HTTP code we received
|
||||
"""
|
||||
return self._read_write_state(
|
||||
return await self._read_write_state(
|
||||
room_id, event_type, None, tok, expect_code, state_key, method="GET"
|
||||
)
|
||||
|
||||
def send_state(
|
||||
async def send_state(
|
||||
self,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
@@ -635,11 +635,11 @@ class RestHelper:
|
||||
Raises:
|
||||
AssertionError: if expect_code doesn't match the HTTP code we received
|
||||
"""
|
||||
return self._read_write_state(
|
||||
return await self._read_write_state(
|
||||
room_id, event_type, body, tok, expect_code, state_key, method="PUT"
|
||||
)
|
||||
|
||||
def upload_media(
|
||||
async def upload_media(
|
||||
self,
|
||||
image_data: bytes,
|
||||
tok: str,
|
||||
@@ -655,7 +655,7 @@ class RestHelper:
|
||||
expect_code: The return code to expect from attempting to upload the media
|
||||
"""
|
||||
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
|
||||
channel = make_request(
|
||||
channel = await make_request(
|
||||
self.reactor,
|
||||
self.site,
|
||||
"POST",
|
||||
@@ -672,7 +672,7 @@ class RestHelper:
|
||||
|
||||
return channel.json_body
|
||||
|
||||
def whoami(
|
||||
async def whoami(
|
||||
self,
|
||||
access_token: str,
|
||||
expect_code: Literal[HTTPStatus.OK, HTTPStatus.UNAUTHORIZED] = HTTPStatus.OK,
|
||||
@@ -684,7 +684,7 @@ class RestHelper:
|
||||
access_token: The user token to use during the request
|
||||
expect_code: The return code to expect from attempting the whoami request
|
||||
"""
|
||||
channel = make_request(
|
||||
channel = await make_request(
|
||||
self.reactor,
|
||||
self.site,
|
||||
"GET",
|
||||
@@ -714,7 +714,7 @@ class RestHelper:
|
||||
issuer=issuer,
|
||||
)
|
||||
|
||||
def login_via_oidc(
|
||||
async def login_via_oidc(
|
||||
self,
|
||||
fake_server: FakeOidcServer,
|
||||
remote_user_id: str,
|
||||
@@ -735,7 +735,7 @@ class RestHelper:
|
||||
"""
|
||||
client_redirect_url = "https://x"
|
||||
userinfo = {"sub": remote_user_id}
|
||||
channel, grant = self.auth_via_oidc(
|
||||
channel, grant = await self.auth_via_oidc(
|
||||
fake_server,
|
||||
userinfo,
|
||||
client_redirect_url,
|
||||
@@ -754,9 +754,9 @@ class RestHelper:
|
||||
assert m, channel.text_body
|
||||
login_token = m.group(1)
|
||||
|
||||
return self.login_via_token(login_token, expected_status), grant
|
||||
return await self.login_via_token(login_token, expected_status), grant
|
||||
|
||||
def login_via_token(
|
||||
async def login_via_token(
|
||||
self,
|
||||
login_token: str,
|
||||
expected_status: int = 200,
|
||||
@@ -774,7 +774,7 @@ class RestHelper:
|
||||
the normal places.
|
||||
"""
|
||||
|
||||
channel = make_request(
|
||||
channel = await make_request(
|
||||
self.reactor,
|
||||
self.site,
|
||||
"POST",
|
||||
@@ -786,7 +786,7 @@ class RestHelper:
|
||||
)
|
||||
return channel.json_body
|
||||
|
||||
def auth_via_oidc(
|
||||
async def auth_via_oidc(
|
||||
self,
|
||||
fake_server: FakeOidcServer,
|
||||
user_info_dict: JsonDict,
|
||||
@@ -834,10 +834,10 @@ class RestHelper:
|
||||
if ui_auth_session_id:
|
||||
# can't set the client redirect url for UI Auth
|
||||
assert client_redirect_url is None
|
||||
oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
|
||||
oauth_uri = await self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
|
||||
else:
|
||||
# otherwise, hit the login redirect endpoint
|
||||
oauth_uri = self.initiate_sso_login(
|
||||
oauth_uri = await self.initiate_sso_login(
|
||||
client_redirect_url, cookies, idp_id=idp_id
|
||||
)
|
||||
|
||||
@@ -850,11 +850,11 @@ class RestHelper:
|
||||
assert oauth_uri_path == fake_server.authorization_endpoint, (
|
||||
"unexpected SSO URI " + oauth_uri_path
|
||||
)
|
||||
return self.complete_oidc_auth(
|
||||
return await self.complete_oidc_auth(
|
||||
fake_server, oauth_uri, cookies, user_info_dict, with_sid=with_sid
|
||||
)
|
||||
|
||||
def complete_oidc_auth(
|
||||
async def complete_oidc_auth(
|
||||
self,
|
||||
fake_serer: FakeOidcServer,
|
||||
oauth_uri: str,
|
||||
@@ -904,7 +904,7 @@ class RestHelper:
|
||||
|
||||
with fake_serer.patch_homeserver(hs=self.hs):
|
||||
# now hit the callback URI with the right params and a made-up code
|
||||
channel = make_request(
|
||||
channel = await make_request(
|
||||
self.reactor,
|
||||
self.site,
|
||||
"GET",
|
||||
@@ -915,7 +915,7 @@ class RestHelper:
|
||||
)
|
||||
return channel, grant
|
||||
|
||||
def initiate_sso_login(
|
||||
async def initiate_sso_login(
|
||||
self,
|
||||
client_redirect_url: str | None,
|
||||
cookies: MutableMapping[str, str],
|
||||
@@ -948,7 +948,7 @@ class RestHelper:
|
||||
# hit the redirect url (which should redirect back to the redirect url. This
|
||||
# is the easiest way of figuring out what the Host header ought to be set to
|
||||
# to keep Synapse happy.
|
||||
channel = make_request(
|
||||
channel = await make_request(
|
||||
self.reactor,
|
||||
self.site,
|
||||
"GET",
|
||||
@@ -967,7 +967,7 @@ class RestHelper:
|
||||
location = get_location(channel)
|
||||
parts = urllib.parse.urlsplit(location)
|
||||
next_uri = urllib.parse.urlunsplit(("", "") + parts[2:])
|
||||
channel = make_request(
|
||||
channel = await make_request(
|
||||
self.reactor,
|
||||
self.site,
|
||||
"GET",
|
||||
@@ -981,7 +981,7 @@ class RestHelper:
|
||||
channel.extract_cookies(cookies)
|
||||
return get_location(channel)
|
||||
|
||||
def initiate_sso_ui_auth(
|
||||
async def initiate_sso_ui_auth(
|
||||
self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
|
||||
) -> str:
|
||||
"""Make a request to the ui-auth-via-sso endpoint, and return the target
|
||||
@@ -1001,7 +1001,7 @@ class RestHelper:
|
||||
+ urllib.parse.urlencode({"session": ui_auth_session_id})
|
||||
)
|
||||
# hit the redirect url (which will issue a cookie and state)
|
||||
channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint)
|
||||
channel = await make_request(self.reactor, self.site, "GET", sso_redirect_endpoint)
|
||||
# that should serve a confirmation page
|
||||
assert channel.code == HTTPStatus.OK, channel.text_body
|
||||
channel.extract_cookies(cookies)
|
||||
@@ -1014,9 +1014,9 @@ class RestHelper:
|
||||
oauth_uri = p.links[0]
|
||||
return oauth_uri
|
||||
|
||||
def send_read_receipt(self, room_id: str, event_id: str, *, tok: str) -> None:
|
||||
async def send_read_receipt(self, room_id: str, event_id: str, *, tok: str) -> None:
|
||||
"""Send a read receipt into the room at the given event"""
|
||||
channel = make_request(
|
||||
channel = await make_request(
|
||||
self.reactor,
|
||||
self.site,
|
||||
method="POST",
|
||||
|
||||
@@ -538,10 +538,26 @@ async def make_request(
|
||||
req.dispatch(root_resource)
|
||||
|
||||
if await_result and req.render_deferred is not None:
|
||||
# Await the handler task directly — the event loop drives
|
||||
# all concurrent tasks (DB operations, background processes)
|
||||
# naturally without needing nest_asyncio.
|
||||
await req.render_deferred
|
||||
import asyncio
|
||||
|
||||
# Advance fake time in a background task so that any
|
||||
# clock.sleep() calls in the handler (e.g., ratelimit pauses)
|
||||
# get resolved. We advance by 0.1s per tick.
|
||||
async def _advance_time() -> None:
|
||||
while not req.render_deferred.done():
|
||||
if clock is not None:
|
||||
clock.advance(0.1)
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
advancer = asyncio.ensure_future(_advance_time())
|
||||
try:
|
||||
await req.render_deferred
|
||||
finally:
|
||||
advancer.cancel()
|
||||
try:
|
||||
await advancer
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
return channel
|
||||
|
||||
|
||||
Reference in New Issue
Block a user