From 582be03d65e722e3e4e744ba99bcb7a4a9c6fa52 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Mon, 23 Mar 2026 10:36:06 +0000 Subject: [PATCH] convert login tests to asyncio and make them work with faketime --- tests/rest/client/test_login.py | 328 ++++++++++++++++---------------- tests/rest/client/utils.py | 108 +++++------ tests/server.py | 24 ++- 3 files changed, 238 insertions(+), 222 deletions(-) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index f6e2238992..45987c5086 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -206,11 +206,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): }, } ) - def test_POST_ratelimiting_per_address(self) -> None: + async def test_POST_ratelimiting_per_address(self) -> None: # Create different users so we're sure not to be bothered by the per-user # ratelimiter. for i in range(6): - self.register_user("kermit" + str(i), "monkey") + await self.register_user("kermit" + str(i), "monkey") for i in range(6): params = { @@ -218,7 +218,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } - channel = self.make_request(b"POST", LOGIN_URL, params) + channel = await self.make_request(b"POST", LOGIN_URL, params) if i == 5: self.assertEqual(channel.code, 429, msg=channel.result) @@ -240,7 +240,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } - channel = self.make_request(b"POST", LOGIN_URL, params) + channel = await self.make_request(b"POST", LOGIN_URL, params) self.assertEqual(channel.code, 200, msg=channel.result) @@ -257,8 +257,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): }, } ) - def test_POST_ratelimiting_per_account(self) -> None: - self.register_user("kermit", "monkey") + async def test_POST_ratelimiting_per_account(self) -> None: + await self.register_user("kermit", "monkey") for i in range(6): params = { @@ -266,7 +266,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - channel = self.make_request(b"POST", LOGIN_URL, params) + channel = await self.make_request(b"POST", LOGIN_URL, params) if i == 5: self.assertEqual(channel.code, 429, msg=channel.result) @@ -288,7 +288,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - channel = self.make_request(b"POST", LOGIN_URL, params) + channel = await self.make_request(b"POST", LOGIN_URL, params) self.assertEqual(channel.code, 200, msg=channel.result) @@ -305,8 +305,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): }, } ) - def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: - self.register_user("kermit", "monkey") + async def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: + await self.register_user("kermit", "monkey") for i in range(6): params = { @@ -314,7 +314,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } - channel = self.make_request(b"POST", LOGIN_URL, params) + channel = await self.make_request(b"POST", LOGIN_URL, params) if i == 5: self.assertEqual(channel.code, 429, msg=channel.result) @@ -336,7 +336,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } - channel = self.make_request(b"POST", LOGIN_URL, params) + channel = await self.make_request(b"POST", LOGIN_URL, params) self.assertEqual(channel.code, 403, msg=channel.result) @@ -430,57 +430,57 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) @override_config({"session_lifetime": "24h"}) - def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None: - self.register_user("kermit", "monkey") + async def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None: + await self.register_user("kermit", "monkey") # log in as normal - access_token = self.login("kermit", "monkey") + access_token = await self.login("kermit", "monkey") # we should now be able to make requests with the access token - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + channel = await self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEqual(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + channel = await self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEqual(channel.code, 401, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], True) # Now try to hard logout this session - channel = self.make_request(b"POST", "/logout", access_token=access_token) + channel = await self.make_request(b"POST", "/logout", access_token=access_token) self.assertEqual(channel.code, 200, msg=channel.result) @override_config({"session_lifetime": "24h"}) - def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out( + async def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out( self, ) -> None: - self.register_user("kermit", "monkey") + await self.register_user("kermit", "monkey") # log in as normal - access_token = self.login("kermit", "monkey") + access_token = await self.login("kermit", "monkey") # we should now be able to make requests with the access token - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + channel = await self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEqual(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + channel = await self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEqual(channel.code, 401, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], True) # Now try to hard log out all of the user's sessions - channel = self.make_request(b"POST", "/logout/all", access_token=access_token) + channel = await self.make_request(b"POST", "/logout/all", access_token=access_token) self.assertEqual(channel.code, 200, msg=channel.result) - def test_login_with_overly_long_device_id_fails(self) -> None: - self.register_user("mickey", "cheese") + async def test_login_with_overly_long_device_id_fails(self) -> None: + await self.register_user("mickey", "cheese") # create a device_id longer than 512 characters device_id = "yolo" * 512 @@ -493,7 +493,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } # make a login request with the bad device_id - channel = self.make_request( + channel = await self.make_request( "POST", "/_matrix/client/v3/login", body, @@ -514,8 +514,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } } ) - def test_require_approval(self) -> None: - channel = self.make_request( + async def test_require_approval(self) -> None: + channel = await self.make_request( "POST", "register", { @@ -535,25 +535,25 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - channel = self.make_request("POST", LOGIN_URL, params) + channel = await self.make_request("POST", LOGIN_URL, params) self.assertEqual(403, channel.code, channel.result) self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"]) self.assertEqual( ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"] ) - def test_get_login_flows_with_login_via_existing_disabled(self) -> None: + async def test_get_login_flows_with_login_via_existing_disabled(self) -> None: """GET /login should return m.login.token without get_login_token""" - channel = self.make_request("GET", "/_matrix/client/r0/login") + channel = await self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) flows = {flow["type"]: flow for flow in channel.json_body["flows"]} self.assertNotIn("m.login.token", flows) @override_config({"login_via_existing_session": {"enabled": True}}) - def test_get_login_flows_with_login_via_existing_enabled(self) -> None: + async def test_get_login_flows_with_login_via_existing_enabled(self) -> None: """GET /login should return m.login.token with get_login_token true""" - channel = self.make_request("GET", "/_matrix/client/r0/login") + channel = await self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) self.assertCountEqual( @@ -576,13 +576,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): ] } ) - def test_spam_checker_allow(self) -> None: + async def test_spam_checker_allow(self) -> None: """Check that that adding a spam checker doesn't break login.""" - self.register_user("kermit", "monkey") + await self.register_user("kermit", "monkey") body = {"type": "m.login.password", "user": "kermit", "password": "monkey"} - channel = self.make_request( + channel = await self.make_request( "POST", "/_matrix/client/r0/login", body, @@ -600,14 +600,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): ] } ) - def test_spam_checker_deny(self) -> None: + async def test_spam_checker_deny(self) -> None: """Check that login""" - self.register_user("kermit", "monkey") + await self.register_user("kermit", "monkey") body = {"type": "m.login.password", "user": "kermit", "password": "monkey"} - channel = self.make_request( + channel = await self.make_request( "POST", "/_matrix/client/r0/login", body, @@ -677,9 +677,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): d.update(build_synapse_client_resource_tree(self.hs)) return d - def test_get_login_flows(self) -> None: + async def test_get_login_flows(self) -> None: """GET /login should return password and SSO flows""" - channel = self.make_request("GET", "/_matrix/client/r0/login") + channel = await self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) expected_flow_types = [ @@ -704,17 +704,17 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ], ) - def test_multi_sso_redirect(self) -> None: + async def test_multi_sso_redirect(self) -> None: """/login/sso/redirect should redirect to an identity picker""" # first hit the redirect url, which should redirect to our idp picker - channel = self._make_sso_redirect_request(None) + channel = await self._make_sso_redirect_request(None) self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers uri = location_headers[0] # hitting that picker should give us some HTML - channel = self.make_request("GET", uri) + channel = await self.make_request("GET", uri) self.assertEqual(channel.code, 200, channel.result) # parse the form to check it has fields assumed elsewhere in this class @@ -734,10 +734,10 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"]) - def test_multi_sso_redirect_to_cas(self) -> None: + async def test_multi_sso_redirect_to_cas(self) -> None: """If CAS is chosen, should redirect to the CAS server""" - channel = self.make_request( + channel = await self.make_request( "GET", "/_synapse/client/pick_idp?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) @@ -758,7 +758,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ) # follow the redirect - channel = self.make_request( + channel = await self.make_request( "GET", # We have to make this relative to be compatible with `make_request(...)` get_relative_uri_from_absolute_uri(sso_login_redirect_uri), @@ -787,9 +787,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): service_uri_params = urllib.parse.parse_qs(service_uri_query) self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) - def test_multi_sso_redirect_to_saml(self) -> None: + async def test_multi_sso_redirect_to_saml(self) -> None: """If SAML is chosen, should redirect to the SAML server""" - channel = self.make_request( + channel = await self.make_request( "GET", "/_synapse/client/pick_idp?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) @@ -809,7 +809,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ) # follow the redirect - channel = self.make_request( + channel = await self.make_request( "GET", # We have to make this relative to be compatible with `make_request(...)` get_relative_uri_from_absolute_uri(sso_login_redirect_uri), @@ -835,14 +835,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): relay_state_param = saml_uri_params["RelayState"][0] self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) - def test_login_via_oidc(self) -> None: + async def test_login_via_oidc(self) -> None: """If OIDC is chosen, should redirect to the OIDC auth endpoint""" fake_oidc_server = self.helper.fake_oidc_server() with fake_oidc_server.patch_homeserver(hs=self.hs): # pick the default OIDC provider - channel = self.make_request( + channel = await self.make_request( "GET", f"/_synapse/client/pick_idp?redirectUrl={urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)}&idp=oidc", ) @@ -861,7 +861,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): with fake_oidc_server.patch_homeserver(hs=self.hs): # follow the redirect - channel = self.make_request( + channel = await self.make_request( "GET", # We have to make this relative to be compatible with `make_request(...)` get_relative_uri_from_absolute_uri(sso_login_redirect_uri), @@ -897,7 +897,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): TEST_CLIENT_REDIRECT_URL, ) - channel, _ = self.helper.complete_oidc_auth( + channel, _ = await self.helper.complete_oidc_auth( fake_oidc_server, oidc_uri, cookies, {"sub": "user1"} ) @@ -925,7 +925,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # finally, submit the matrix login token to the login API, which gives us our # matrix access token, mxid, and device id. login_token = params[2][1] - chan = self.make_request( + chan = await self.make_request( "POST", "/login", content={"type": "m.login.token", "token": login_token}, @@ -933,15 +933,15 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.json_body["user_id"], "@user1:test") - def test_multi_sso_redirect_unknown_idp(self) -> None: + async def test_multi_sso_redirect_unknown_idp(self) -> None: """An unknown IdP should cause a 400 bad request error""" - channel = self.make_request( + channel = await self.make_request( "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", ) self.assertEqual(channel.code, 400, channel.result) - def test_multi_sso_redirect_unknown_idp_as_url(self) -> None: + async def test_multi_sso_redirect_unknown_idp_as_url(self) -> None: """ An unknown IdP that looks like a URL should cause a 400 bad request error (to avoid open redirects). @@ -954,24 +954,24 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): in the URL building tests to cover this case but is only a unit test vs something at the REST layer here that covers things end-to-end. """ - channel = self.make_request( + channel = await self.make_request( "GET", "/_synapse/client/pick_idp?redirectUrl=something&idp=https://element.io/", ) self.assertEqual(channel.code, 400, channel.result) - def test_client_idp_redirect_to_unknown(self) -> None: + async def test_client_idp_redirect_to_unknown(self) -> None: """If the client tries to pick an unknown IdP, return a 404""" - channel = self._make_sso_redirect_request("xxx") + channel = await self._make_sso_redirect_request("xxx") self.assertEqual(channel.code, 404, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") - def test_client_idp_redirect_to_oidc(self) -> None: + async def test_client_idp_redirect_to_oidc(self) -> None: """If the client pick a known IdP, redirect to it""" fake_oidc_server = self.helper.fake_oidc_server() with fake_oidc_server.patch_homeserver(hs=self.hs): - channel = self._make_sso_redirect_request("oidc") + channel = await self._make_sso_redirect_request("oidc") self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ -981,7 +981,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # it should redirect us to the auth page of the OIDC server self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint) - def _make_sso_redirect_request(self, idp_prov: str | None = None) -> FakeChannel: + async def _make_sso_redirect_request(self, idp_prov: str | None = None) -> FakeChannel: """Send a request to /_matrix/client/r0/login/sso/redirect ... possibly specifying an IDP provider @@ -991,7 +991,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): endpoint += "/" + idp_prov endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - return self.make_request( + return await self.make_request( "GET", endpoint, custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)], @@ -1011,7 +1011,7 @@ class CASTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.base_url = "https://matrix.goodserver.com/" self.redirect_path = "_synapse/client/login/sso/redirect/confirm" @@ -1054,7 +1054,7 @@ class CASTestCase(unittest.HomeserverTestCase): mocked_http_client = Mock(spec=["get_raw"]) mocked_http_client.get_raw.side_effect = get_raw - self.hs = self.setup_test_homeserver( + self.hs = await self.setup_test_homeserver( config=config, proxied_http_client=mocked_http_client, ) @@ -1064,7 +1064,7 @@ class CASTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.deactivate_account_handler = hs.get_deactivate_account_handler() - def test_cas_redirect_confirm(self) -> None: + async def test_cas_redirect_confirm(self) -> None: """Tests that the SSO login flow serves a confirmation page before redirecting a user to the redirect URL. """ @@ -1079,7 +1079,7 @@ class CASTestCase(unittest.HomeserverTestCase): cas_ticket_url = urllib.parse.urlunparse(url_parts) # Get Synapse to call the fake CAS and serve the template. - channel = self.make_request("GET", cas_ticket_url) + channel = await self.make_request("GET", cas_ticket_url) # Test that the response is HTML. self.assertEqual(channel.code, 200, channel.result) @@ -1105,15 +1105,15 @@ class CASTestCase(unittest.HomeserverTestCase): } } ) - def test_cas_redirect_whitelisted(self) -> None: + async def test_cas_redirect_whitelisted(self) -> None: """Tests that the SSO login flow serves a redirect to a whitelisted url""" - self._test_redirect("https://legit-site.com/") + await self._test_redirect("https://legit-site.com/") @override_config({"public_baseurl": "https://example.com"}) - def test_cas_redirect_login_fallback(self) -> None: - self._test_redirect("https://example.com/_matrix/static/client/login") + async def test_cas_redirect_login_fallback(self) -> None: + await self._test_redirect("https://example.com/_matrix/static/client/login") - def _test_redirect(self, redirect_url: str) -> None: + async def _test_redirect(self, redirect_url: str) -> None: """Tests that the SSO login flow serves a redirect for the given redirect URL.""" cas_ticket_url = ( "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" @@ -1121,7 +1121,7 @@ class CASTestCase(unittest.HomeserverTestCase): ) # Get Synapse to call the fake CAS and serve the template. - channel = self.make_request("GET", cas_ticket_url) + channel = await self.make_request("GET", cas_ticket_url) self.assertEqual(channel.code, 302) location_headers = channel.headers.getRawHeaders("Location") @@ -1129,15 +1129,15 @@ class CASTestCase(unittest.HomeserverTestCase): self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) - def test_deactivated_user(self) -> None: + async def test_deactivated_user(self) -> None: """Logging in as a deactivated account should error.""" redirect_url = "https://legit-site.com/" # First login (to create the user). - self._test_redirect(redirect_url) + await self._test_redirect(redirect_url) # Deactivate the account. - self.get_success( + await self.get_success( self.deactivate_account_handler.deactivate_account( self.user_id, False, create_requester(self.user_id) ) @@ -1150,7 +1150,7 @@ class CASTestCase(unittest.HomeserverTestCase): ) # Get Synapse to call the fake CAS and serve the template. - channel = self.make_request("GET", cas_ticket_url) + channel = await self.make_request("GET", cas_ticket_url) # Because the user is deactivated they are served an error template. self.assertEqual(channel.code, 403) @@ -1187,24 +1187,24 @@ class JWTTestCase(unittest.HomeserverTestCase): result: bytes = jwt.encode(header, payload, secret) return result.decode("ascii") - def jwt_login(self, *args: Any) -> FakeChannel: + async def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} - channel = self.make_request(b"POST", LOGIN_URL, params) + channel = await self.make_request(b"POST", LOGIN_URL, params) return channel - def test_login_jwt_valid_registered(self) -> None: - self.register_user("kermit", "monkey") - channel = self.jwt_login({"sub": "kermit"}) + async def test_login_jwt_valid_registered(self) -> None: + await self.register_user("kermit", "monkey") + channel = await self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") - def test_login_jwt_valid_unregistered(self) -> None: - channel = self.jwt_login({"sub": "frog"}) + async def test_login_jwt_valid_unregistered(self) -> None: + channel = await self.jwt_login({"sub": "frog"}) self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") - def test_login_jwt_invalid_signature(self) -> None: - channel = self.jwt_login({"sub": "frog"}, "notsecret") + async def test_login_jwt_invalid_signature(self) -> None: + channel = await self.jwt_login({"sub": "frog"}, "notsecret") self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( @@ -1212,8 +1212,8 @@ class JWTTestCase(unittest.HomeserverTestCase): "JWT validation failed: Signature verification failed", ) - def test_login_jwt_expired(self) -> None: - channel = self.jwt_login({"sub": "frog", "exp": 864000}) + async def test_login_jwt_expired(self) -> None: + channel = await self.jwt_login({"sub": "frog", "exp": 864000}) self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( @@ -1221,9 +1221,9 @@ class JWTTestCase(unittest.HomeserverTestCase): "JWT validation failed: expired_token: The token is expired", ) - def test_login_jwt_not_before(self) -> None: + async def test_login_jwt_not_before(self) -> None: now = int(time.time()) - channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) + channel = await self.jwt_login({"sub": "frog", "nbf": now + 3600}) self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( @@ -1231,22 +1231,22 @@ class JWTTestCase(unittest.HomeserverTestCase): "JWT validation failed: invalid_token: The token is not valid yet", ) - def test_login_no_sub(self) -> None: - channel = self.jwt_login({"username": "root"}) + async def test_login_no_sub(self) -> None: + channel = await self.jwt_login({"username": "root"}) self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Invalid JWT") @override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}}) - def test_login_iss(self) -> None: + async def test_login_iss(self) -> None: """Test validating the issuer claim.""" # A valid issuer. - channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) + channel = await self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") # An invalid issuer. - channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) + channel = await self.jwt_login({"sub": "kermit", "iss": "invalid"}) self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertRegex( @@ -1255,7 +1255,7 @@ class JWTTestCase(unittest.HomeserverTestCase): ) # Not providing an issuer. - channel = self.jwt_login({"sub": "kermit"}) + channel = await self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertRegex( @@ -1263,22 +1263,22 @@ class JWTTestCase(unittest.HomeserverTestCase): r"^JWT validation failed: missing_claim: Missing [\"']iss[\"'] claim$", ) - def test_login_iss_no_config(self) -> None: + async def test_login_iss_no_config(self) -> None: """Test providing an issuer claim without requiring it in the configuration.""" - channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) + channel = await self.jwt_login({"sub": "kermit", "iss": "invalid"}) self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}}) - def test_login_aud(self) -> None: + async def test_login_aud(self) -> None: """Test validating the audience claim.""" # A valid audience. - channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) + channel = await self.jwt_login({"sub": "kermit", "aud": "test-audience"}) self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") # An invalid audience. - channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) + channel = await self.jwt_login({"sub": "kermit", "aud": "invalid"}) self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertRegex( @@ -1287,7 +1287,7 @@ class JWTTestCase(unittest.HomeserverTestCase): ) # Not providing an audience. - channel = self.jwt_login({"sub": "kermit"}) + channel = await self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertRegex( @@ -1295,9 +1295,9 @@ class JWTTestCase(unittest.HomeserverTestCase): r"^JWT validation failed: missing_claim: Missing [\"']aud[\"'] claim$", ) - def test_login_aud_no_config(self) -> None: + async def test_login_aud_no_config(self) -> None: """Test providing an audience without requiring it in the configuration.""" - channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) + channel = await self.jwt_login({"sub": "kermit", "aud": "invalid"}) self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertRegex( @@ -1305,36 +1305,36 @@ class JWTTestCase(unittest.HomeserverTestCase): r"^JWT validation failed: invalid_claim: Invalid claim [\"']aud[\"']$", ) - def test_login_default_sub(self) -> None: + async def test_login_default_sub(self) -> None: """Test reading user ID from the default subject claim.""" - channel = self.jwt_login({"sub": "kermit"}) + channel = await self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "subject_claim": "username"}}) - def test_login_custom_sub(self) -> None: + async def test_login_custom_sub(self) -> None: """Test reading user ID from a custom subject claim.""" - channel = self.jwt_login({"username": "frog"}) + channel = await self.jwt_login({"username": "frog"}) self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") @override_config( {"jwt_config": {**base_config, "display_name_claim": "display_name"}} ) - def test_login_custom_display_name(self) -> None: + async def test_login_custom_display_name(self) -> None: """Test setting a custom display name.""" localpart = "pinkie" user_id = f"@{localpart}:test" display_name = "Pinkie Pie" # Perform the login, specifying a custom display name. - channel = self.jwt_login({"sub": localpart, "display_name": display_name}) + channel = await self.jwt_login({"sub": localpart, "display_name": display_name}) self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], user_id) # Fetch the user's display name and check that it was set correctly. access_token = channel.json_body["access_token"] - channel = self.make_request( + channel = await self.make_request( "GET", f"/_matrix/client/v3/profile/{user_id}/displayname", access_token=access_token, @@ -1342,23 +1342,23 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["displayname"], display_name) - def test_login_no_token(self) -> None: + async def test_login_no_token(self) -> None: params = {"type": "org.matrix.login.jwt"} - channel = self.make_request(b"POST", LOGIN_URL, params) + channel = await self.make_request(b"POST", LOGIN_URL, params) self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") - def test_deactivated_user(self) -> None: + async def test_deactivated_user(self) -> None: """Logging in as a deactivated account should error.""" - user_id = self.register_user("kermit", "monkey") - self.get_success( + user_id = await self.register_user("kermit", "monkey") + await self.get_success( self.hs.get_deactivate_account_handler().deactivate_account( user_id, erase_data=False, requester=create_requester(user_id) ) ) - channel = self.jwt_login({"sub": "kermit"}) + channel = await self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_USER_DEACTIVATED") self.assertEqual( @@ -1436,18 +1436,18 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): result: bytes = jwt.encode(header, payload, secret) return result.decode("ascii") - def jwt_login(self, *args: Any) -> FakeChannel: + async def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} - channel = self.make_request(b"POST", LOGIN_URL, params) + channel = await self.make_request(b"POST", LOGIN_URL, params) return channel - def test_login_jwt_valid(self) -> None: - channel = self.jwt_login({"sub": "kermit"}) + async def test_login_jwt_valid(self) -> None: + channel = await self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") - def test_login_jwt_invalid_signature(self) -> None: - channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) + async def test_login_jwt_invalid_signature(self) -> None: + channel = await self.jwt_login({"sub": "frog"}, self.bad_privatekey) self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( @@ -1465,8 +1465,8 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): register.register_servlets, ] - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.hs = self.setup_test_homeserver() + async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.hs = await self.setup_test_homeserver() self.service = ApplicationService( id="unique_identifier", @@ -1511,23 +1511,23 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.hs.get_datastores().main.services_cache.append(self.msc4190_service) return self.hs - def test_login_appservice_user(self) -> None: + async def test_login_appservice_user(self) -> None: """Test that an appservice user can use /login""" - self.register_appservice_user(AS_USER, self.service.token) + await self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, "identifier": {"type": "m.id.user", "user": AS_USER}, } - channel = self.make_request( + channel = await self.make_request( b"POST", LOGIN_URL, params, access_token=self.service.token ) self.assertEqual(channel.code, 200, msg=channel.result) - def test_login_appservice_msc4190_fail(self) -> None: + async def test_login_appservice_msc4190_fail(self) -> None: """Test that an appservice user can use /login""" - self.register_appservice_user( + await self.register_appservice_user( "as3_user_alice", self.msc4190_service.token, inhibit_login=True ) @@ -1535,7 +1535,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): "type": login.LoginRestServlet.APPSERVICE_TYPE, "identifier": {"type": "m.id.user", "user": "as3_user_alice"}, } - channel = self.make_request( + channel = await self.make_request( b"POST", LOGIN_URL, params, access_token=self.msc4190_service.token ) @@ -1546,9 +1546,9 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): channel.json_body, ) - def test_login_appservice_user_bot(self) -> None: + async def test_login_appservice_user_bot(self) -> None: """Test that the appservice bot can use /login""" - self.register_appservice_user(AS_USER, self.service.token) + await self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, @@ -1557,51 +1557,51 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): "user": self.service.sender.to_string(), }, } - channel = self.make_request( + channel = await self.make_request( b"POST", LOGIN_URL, params, access_token=self.service.token ) self.assertEqual(channel.code, 200, msg=channel.result) - def test_login_appservice_wrong_user(self) -> None: + async def test_login_appservice_wrong_user(self) -> None: """Test that non-as users cannot login with the as token""" - self.register_appservice_user(AS_USER, self.service.token) + await self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, "identifier": {"type": "m.id.user", "user": "fibble_wibble"}, } - channel = self.make_request( + channel = await self.make_request( b"POST", LOGIN_URL, params, access_token=self.service.token ) self.assertEqual(channel.code, 403, msg=channel.result) - def test_login_appservice_wrong_as(self) -> None: + async def test_login_appservice_wrong_as(self) -> None: """Test that as users cannot login with wrong as token""" - self.register_appservice_user(AS_USER, self.service.token) + await self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, "identifier": {"type": "m.id.user", "user": AS_USER}, } - channel = self.make_request( + channel = await self.make_request( b"POST", LOGIN_URL, params, access_token=self.another_service.token ) self.assertEqual(channel.code, 403, msg=channel.result) - def test_login_appservice_no_token(self) -> None: + async def test_login_appservice_no_token(self) -> None: """Test that users must provide a token when using the appservice login method """ - self.register_appservice_user(AS_USER, self.service.token) + await self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, "identifier": {"type": "m.id.user", "user": AS_USER}, } - channel = self.make_request(b"POST", LOGIN_URL, params) + channel = await self.make_request(b"POST", LOGIN_URL, params) self.assertEqual(channel.code, 401, msg=channel.result) @@ -1616,10 +1616,10 @@ class UsernamePickerTestCase(HomeserverTestCase): account.register_servlets, ] - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.http_client = Mock(spec=["get_file"]) self.http_client.get_file.side_effect = mock_get_file - hs = self.setup_test_homeserver( + hs = await self.setup_test_homeserver( proxied_blocklisted_http_client=self.http_client ) return hs @@ -1648,7 +1648,7 @@ class UsernamePickerTestCase(HomeserverTestCase): d.update(build_synapse_client_resource_tree(self.hs)) return d - def proceed_to_username_picker_page( + async def proceed_to_username_picker_page( self, fake_oidc_server: FakeOidcServer, displayname: str, @@ -1656,7 +1656,7 @@ class UsernamePickerTestCase(HomeserverTestCase): picture: str, ) -> tuple[str, str]: # do the start of the login flow - channel, _ = self.helper.auth_via_oidc( + channel, _ = await self.helper.auth_via_oidc( fake_oidc_server, { "sub": "tester", @@ -1701,7 +1701,7 @@ class UsernamePickerTestCase(HomeserverTestCase): return picker_url, session_id - def test_username_picker_use_displayname_avatar_and_email(self) -> None: + async def test_username_picker_use_displayname_avatar_and_email(self) -> None: """Test the happy path of a username picker flow with using displayname, avatar and email.""" fake_oidc_server = self.helper.fake_oidc_server() @@ -1711,7 +1711,7 @@ class UsernamePickerTestCase(HomeserverTestCase): email = "bobby@test.com" picture = "mxc://test/avatar_url" - picker_url, session_id = self.proceed_to_username_picker_page( + picker_url, session_id = await self.proceed_to_username_picker_page( fake_oidc_server, displayname, email, picture ) @@ -1726,7 +1726,7 @@ class UsernamePickerTestCase(HomeserverTestCase): b"use_email": email, } ).encode("utf8") - chan = self.make_request( + chan = await self.make_request( "POST", path=picker_url, content=content, @@ -1740,7 +1740,7 @@ class UsernamePickerTestCase(HomeserverTestCase): assert location_headers # send a request to the completion page, which should 302 to the client redirectUrl - chan = self.make_request( + chan = await self.make_request( "GET", path=location_headers[0], custom_headers=[("Cookie", "username_mapping_session=" + session_id)], @@ -1765,7 +1765,7 @@ class UsernamePickerTestCase(HomeserverTestCase): # finally, submit the matrix login token to the login API, which gives us our # matrix access token, mxid, and device id. - chan = self.make_request( + chan = await self.make_request( "POST", "/login", content={"type": "m.login.token", "token": login_token}, @@ -1774,7 +1774,7 @@ class UsernamePickerTestCase(HomeserverTestCase): self.assertEqual(chan.json_body["user_id"], mxid) # ensure the displayname and avatar from the OIDC response have been configured for the user. - channel = self.make_request( + channel = await self.make_request( "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"] ) self.assertEqual(channel.code, 200, channel.result) @@ -1782,13 +1782,13 @@ class UsernamePickerTestCase(HomeserverTestCase): self.assertEqual(displayname, channel.json_body["displayname"]) # ensure the email from the OIDC response has been configured for the user. - channel = self.make_request( + channel = await self.make_request( "GET", "/account/3pid", access_token=chan.json_body["access_token"] ) self.assertEqual(channel.code, 200, channel.result) self.assertEqual(email, channel.json_body["threepids"][0]["address"]) - def test_username_picker_dont_use_displayname_avatar_or_email(self) -> None: + async def test_username_picker_dont_use_displayname_avatar_or_email(self) -> None: """Test the happy path of a username picker flow without using displayname, avatar or email.""" fake_oidc_server = self.helper.fake_oidc_server() @@ -1799,7 +1799,7 @@ class UsernamePickerTestCase(HomeserverTestCase): picture = "mxc://test/avatar_url" username = "bobby" - picker_url, session_id = self.proceed_to_username_picker_page( + picker_url, session_id = await self.proceed_to_username_picker_page( fake_oidc_server, displayname, email, picture ) @@ -1813,7 +1813,7 @@ class UsernamePickerTestCase(HomeserverTestCase): b"use_avatar": b"false", } ).encode("utf8") - chan = self.make_request( + chan = await self.make_request( "POST", path=picker_url, content=content, @@ -1827,7 +1827,7 @@ class UsernamePickerTestCase(HomeserverTestCase): assert location_headers # send a request to the completion page, which should 302 to the client redirectUrl - chan = self.make_request( + chan = await self.make_request( "GET", path=location_headers[0], custom_headers=[("Cookie", "username_mapping_session=" + session_id)], @@ -1852,7 +1852,7 @@ class UsernamePickerTestCase(HomeserverTestCase): # finally, submit the matrix login token to the login API, which gives us our # matrix access token, mxid, and device id. - chan = self.make_request( + chan = await self.make_request( "POST", "/login", content={"type": "m.login.token", "token": login_token}, @@ -1861,7 +1861,7 @@ class UsernamePickerTestCase(HomeserverTestCase): self.assertEqual(chan.json_body["user_id"], mxid) # ensure the displayname and avatar from the OIDC response have not been configured for the user. - channel = self.make_request( + channel = await self.make_request( "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"] ) self.assertEqual(channel.code, 200, channel.result) @@ -1869,7 +1869,7 @@ class UsernamePickerTestCase(HomeserverTestCase): self.assertEqual(username, channel.json_body["displayname"]) # ensure the email from the OIDC response has not been configured for the user. - channel = self.make_request( + channel = await self.make_request( "GET", "/account/3pid", access_token=chan.json_body["access_token"] ) self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 81721c7dcf..391177090f 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -81,7 +81,7 @@ class RestHelper: auth_user_id: str | None @overload - def create_room_as( + async def create_room_as( self, room_creator: str | None = ..., is_public: bool | None = ..., @@ -93,7 +93,7 @@ class RestHelper: ) -> str: ... @overload - def create_room_as( + async def create_room_as( self, room_creator: str | None = ..., is_public: bool | None = ..., @@ -104,7 +104,7 @@ class RestHelper: custom_headers: Iterable[tuple[AnyStr, AnyStr]] | None = ..., ) -> str | None: ... - def create_room_as( + async def create_room_as( self, room_creator: str | None = None, is_public: bool | None = True, @@ -149,7 +149,7 @@ class RestHelper: if tok: path = path + "?access_token=%s" % tok - channel = make_request( + channel = await make_request( self.reactor, self.site, "POST", @@ -166,7 +166,7 @@ class RestHelper: else: return None - def invite( + async def invite( self, room: str, src: str | None = None, @@ -175,7 +175,7 @@ class RestHelper: tok: str | None = None, extra_data: dict | None = None, ) -> JsonDict: - return self.change_membership( + return await self.change_membership( room=room, src=src, targ=targ, @@ -185,7 +185,7 @@ class RestHelper: extra_data=extra_data, ) - def join( + async def join( self, room: str, user: str, @@ -195,7 +195,7 @@ class RestHelper: expect_errcode: Codes | None = None, expect_additional_fields: dict | None = None, ) -> JsonDict: - return self.change_membership( + return await self.change_membership( room=room, src=user, targ=user, @@ -207,7 +207,7 @@ class RestHelper: expect_additional_fields=expect_additional_fields, ) - def knock( + async def knock( self, room: str | None = None, user: str | None = None, @@ -225,7 +225,7 @@ class RestHelper: if reason: data["reason"] = reason - channel = make_request( + channel = await make_request( self.reactor, self.site, "POST", @@ -241,14 +241,14 @@ class RestHelper: self.auth_user_id = temp_id - def leave( + async def leave( self, room: str, user: str | None = None, expect_code: int = HTTPStatus.OK, tok: str | None = None, ) -> JsonDict: - return self.change_membership( + return await self.change_membership( room=room, src=user, targ=user, @@ -257,7 +257,7 @@ class RestHelper: expect_code=expect_code, ) - def ban( + async def ban( self, room: str, src: str, @@ -266,7 +266,7 @@ class RestHelper: tok: str | None = None, ) -> JsonDict: """A convenience helper: `change_membership` with `membership` preset to "ban".""" - return self.change_membership( + return await self.change_membership( room=room, src=src, targ=targ, @@ -275,7 +275,7 @@ class RestHelper: expect_code=expect_code, ) - def change_membership( + async def change_membership( self, room: str, src: str | None, @@ -325,7 +325,7 @@ class RestHelper: data = {"membership": membership} data.update(extra_data or {}) - channel = make_request( + channel = await make_request( self.reactor, self.site, "PUT", @@ -372,7 +372,7 @@ class RestHelper: self.auth_user_id = temp_id return channel.json_body - def send( + async def send( self, room_id: str, body: str | None = None, @@ -387,7 +387,7 @@ class RestHelper: content = {"msgtype": "m.text", "body": body} - return self.send_event( + return await self.send_event( room_id, type, content, @@ -397,7 +397,7 @@ class RestHelper: custom_headers=custom_headers, ) - def send_messages( + async def send_messages( self, room_id: str, num_events: int, @@ -413,7 +413,7 @@ class RestHelper: event_ids = [] for event_index in range(num_events): - response = self.send_event( + response = await self.send_event( room_id, EventTypes.Message, content_fn(event_index), @@ -423,7 +423,7 @@ class RestHelper: return event_ids - def send_event( + async def send_event( self, room_id: str, type: str, @@ -440,7 +440,7 @@ class RestHelper: if tok: path = path + "?access_token=%s" % tok - channel = make_request( + channel = await make_request( self.reactor, self.site, "PUT", @@ -457,7 +457,7 @@ class RestHelper: return channel.json_body - def send_sticky_event( + async def send_sticky_event( self, room_id: str, type: str, @@ -480,7 +480,7 @@ class RestHelper: if tok: path = path + f"&access_token={tok}" - channel = make_request( + channel = await make_request( self.reactor, self.site, "PUT", @@ -495,7 +495,7 @@ class RestHelper: return channel.json_body - def get_event( + async def get_event( self, room_id: str, event_id: str, @@ -517,7 +517,7 @@ class RestHelper: if tok: path = path + f"?access_token={tok}" - channel = make_request( + channel = await make_request( self.reactor, self.site, "GET", @@ -532,7 +532,7 @@ class RestHelper: return channel.json_body - def _read_write_state( + async def _read_write_state( self, room_id: str, event_type: str, @@ -573,7 +573,7 @@ class RestHelper: if body is not None: content = json.dumps(body).encode("utf8") - channel = make_request(self.reactor, self.site, method, path, content) + channel = await make_request(self.reactor, self.site, method, path, content) assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, @@ -583,7 +583,7 @@ class RestHelper: return channel.json_body - def get_state( + async def get_state( self, room_id: str, event_type: str, @@ -606,11 +606,11 @@ class RestHelper: Raises: AssertionError: if expect_code doesn't match the HTTP code we received """ - return self._read_write_state( + return await self._read_write_state( room_id, event_type, None, tok, expect_code, state_key, method="GET" ) - def send_state( + async def send_state( self, room_id: str, event_type: str, @@ -635,11 +635,11 @@ class RestHelper: Raises: AssertionError: if expect_code doesn't match the HTTP code we received """ - return self._read_write_state( + return await self._read_write_state( room_id, event_type, body, tok, expect_code, state_key, method="PUT" ) - def upload_media( + async def upload_media( self, image_data: bytes, tok: str, @@ -655,7 +655,7 @@ class RestHelper: expect_code: The return code to expect from attempting to upload the media """ path = "/_matrix/media/r0/upload?filename=%s" % (filename,) - channel = make_request( + channel = await make_request( self.reactor, self.site, "POST", @@ -672,7 +672,7 @@ class RestHelper: return channel.json_body - def whoami( + async def whoami( self, access_token: str, expect_code: Literal[HTTPStatus.OK, HTTPStatus.UNAUTHORIZED] = HTTPStatus.OK, @@ -684,7 +684,7 @@ class RestHelper: access_token: The user token to use during the request expect_code: The return code to expect from attempting the whoami request """ - channel = make_request( + channel = await make_request( self.reactor, self.site, "GET", @@ -714,7 +714,7 @@ class RestHelper: issuer=issuer, ) - def login_via_oidc( + async def login_via_oidc( self, fake_server: FakeOidcServer, remote_user_id: str, @@ -735,7 +735,7 @@ class RestHelper: """ client_redirect_url = "https://x" userinfo = {"sub": remote_user_id} - channel, grant = self.auth_via_oidc( + channel, grant = await self.auth_via_oidc( fake_server, userinfo, client_redirect_url, @@ -754,9 +754,9 @@ class RestHelper: assert m, channel.text_body login_token = m.group(1) - return self.login_via_token(login_token, expected_status), grant + return await self.login_via_token(login_token, expected_status), grant - def login_via_token( + async def login_via_token( self, login_token: str, expected_status: int = 200, @@ -774,7 +774,7 @@ class RestHelper: the normal places. """ - channel = make_request( + channel = await make_request( self.reactor, self.site, "POST", @@ -786,7 +786,7 @@ class RestHelper: ) return channel.json_body - def auth_via_oidc( + async def auth_via_oidc( self, fake_server: FakeOidcServer, user_info_dict: JsonDict, @@ -834,10 +834,10 @@ class RestHelper: if ui_auth_session_id: # can't set the client redirect url for UI Auth assert client_redirect_url is None - oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) + oauth_uri = await self.initiate_sso_ui_auth(ui_auth_session_id, cookies) else: # otherwise, hit the login redirect endpoint - oauth_uri = self.initiate_sso_login( + oauth_uri = await self.initiate_sso_login( client_redirect_url, cookies, idp_id=idp_id ) @@ -850,11 +850,11 @@ class RestHelper: assert oauth_uri_path == fake_server.authorization_endpoint, ( "unexpected SSO URI " + oauth_uri_path ) - return self.complete_oidc_auth( + return await self.complete_oidc_auth( fake_server, oauth_uri, cookies, user_info_dict, with_sid=with_sid ) - def complete_oidc_auth( + async def complete_oidc_auth( self, fake_serer: FakeOidcServer, oauth_uri: str, @@ -904,7 +904,7 @@ class RestHelper: with fake_serer.patch_homeserver(hs=self.hs): # now hit the callback URI with the right params and a made-up code - channel = make_request( + channel = await make_request( self.reactor, self.site, "GET", @@ -915,7 +915,7 @@ class RestHelper: ) return channel, grant - def initiate_sso_login( + async def initiate_sso_login( self, client_redirect_url: str | None, cookies: MutableMapping[str, str], @@ -948,7 +948,7 @@ class RestHelper: # hit the redirect url (which should redirect back to the redirect url. This # is the easiest way of figuring out what the Host header ought to be set to # to keep Synapse happy. - channel = make_request( + channel = await make_request( self.reactor, self.site, "GET", @@ -967,7 +967,7 @@ class RestHelper: location = get_location(channel) parts = urllib.parse.urlsplit(location) next_uri = urllib.parse.urlunsplit(("", "") + parts[2:]) - channel = make_request( + channel = await make_request( self.reactor, self.site, "GET", @@ -981,7 +981,7 @@ class RestHelper: channel.extract_cookies(cookies) return get_location(channel) - def initiate_sso_ui_auth( + async def initiate_sso_ui_auth( self, ui_auth_session_id: str, cookies: MutableMapping[str, str] ) -> str: """Make a request to the ui-auth-via-sso endpoint, and return the target @@ -1001,7 +1001,7 @@ class RestHelper: + urllib.parse.urlencode({"session": ui_auth_session_id}) ) # hit the redirect url (which will issue a cookie and state) - channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint) + channel = await make_request(self.reactor, self.site, "GET", sso_redirect_endpoint) # that should serve a confirmation page assert channel.code == HTTPStatus.OK, channel.text_body channel.extract_cookies(cookies) @@ -1014,9 +1014,9 @@ class RestHelper: oauth_uri = p.links[0] return oauth_uri - def send_read_receipt(self, room_id: str, event_id: str, *, tok: str) -> None: + async def send_read_receipt(self, room_id: str, event_id: str, *, tok: str) -> None: """Send a read receipt into the room at the given event""" - channel = make_request( + channel = await make_request( self.reactor, self.site, method="POST", diff --git a/tests/server.py b/tests/server.py index ce5d5cadf6..945c55cc27 100644 --- a/tests/server.py +++ b/tests/server.py @@ -538,10 +538,26 @@ async def make_request( req.dispatch(root_resource) if await_result and req.render_deferred is not None: - # Await the handler task directly — the event loop drives - # all concurrent tasks (DB operations, background processes) - # naturally without needing nest_asyncio. - await req.render_deferred + import asyncio + + # Advance fake time in a background task so that any + # clock.sleep() calls in the handler (e.g., ratelimit pauses) + # get resolved. We advance by 0.1s per tick. + async def _advance_time() -> None: + while not req.render_deferred.done(): + if clock is not None: + clock.advance(0.1) + await asyncio.sleep(0.01) + + advancer = asyncio.ensure_future(_advance_time()) + try: + await req.render_deferred + finally: + advancer.cancel() + try: + await advancer + except asyncio.CancelledError: + pass return channel