convert login tests to asyncio and make them work with faketime

This commit is contained in:
Matthew Hodgson
2026-03-23 10:36:06 +00:00
parent ac2fb5cacd
commit 582be03d65
3 changed files with 238 additions and 222 deletions

View File

@@ -206,11 +206,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
},
}
)
def test_POST_ratelimiting_per_address(self) -> None:
async def test_POST_ratelimiting_per_address(self) -> None:
# Create different users so we're sure not to be bothered by the per-user
# ratelimiter.
for i in range(6):
self.register_user("kermit" + str(i), "monkey")
await self.register_user("kermit" + str(i), "monkey")
for i in range(6):
params = {
@@ -218,7 +218,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
channel = self.make_request(b"POST", LOGIN_URL, params)
channel = await self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
self.assertEqual(channel.code, 429, msg=channel.result)
@@ -240,7 +240,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
channel = self.make_request(b"POST", LOGIN_URL, params)
channel = await self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.code, 200, msg=channel.result)
@@ -257,8 +257,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
},
}
)
def test_POST_ratelimiting_per_account(self) -> None:
self.register_user("kermit", "monkey")
async def test_POST_ratelimiting_per_account(self) -> None:
await self.register_user("kermit", "monkey")
for i in range(6):
params = {
@@ -266,7 +266,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
channel = self.make_request(b"POST", LOGIN_URL, params)
channel = await self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
self.assertEqual(channel.code, 429, msg=channel.result)
@@ -288,7 +288,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
channel = self.make_request(b"POST", LOGIN_URL, params)
channel = await self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.code, 200, msg=channel.result)
@@ -305,8 +305,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
},
}
)
def test_POST_ratelimiting_per_account_failed_attempts(self) -> None:
self.register_user("kermit", "monkey")
async def test_POST_ratelimiting_per_account_failed_attempts(self) -> None:
await self.register_user("kermit", "monkey")
for i in range(6):
params = {
@@ -314,7 +314,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
channel = self.make_request(b"POST", LOGIN_URL, params)
channel = await self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
self.assertEqual(channel.code, 429, msg=channel.result)
@@ -336,7 +336,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
channel = self.make_request(b"POST", LOGIN_URL, params)
channel = await self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.code, 403, msg=channel.result)
@@ -430,57 +430,57 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
@override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
self.register_user("kermit", "monkey")
async def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
await self.register_user("kermit", "monkey")
# log in as normal
access_token = self.login("kermit", "monkey")
access_token = await self.login("kermit", "monkey")
# we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
channel = await self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
channel = await self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
# Now try to hard logout this session
channel = self.make_request(b"POST", "/logout", access_token=access_token)
channel = await self.make_request(b"POST", "/logout", access_token=access_token)
self.assertEqual(channel.code, 200, msg=channel.result)
@override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(
async def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(
self,
) -> None:
self.register_user("kermit", "monkey")
await self.register_user("kermit", "monkey")
# log in as normal
access_token = self.login("kermit", "monkey")
access_token = await self.login("kermit", "monkey")
# we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
channel = await self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
channel = await self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
# Now try to hard log out all of the user's sessions
channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
channel = await self.make_request(b"POST", "/logout/all", access_token=access_token)
self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_with_overly_long_device_id_fails(self) -> None:
self.register_user("mickey", "cheese")
async def test_login_with_overly_long_device_id_fails(self) -> None:
await self.register_user("mickey", "cheese")
# create a device_id longer than 512 characters
device_id = "yolo" * 512
@@ -493,7 +493,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
}
# make a login request with the bad device_id
channel = self.make_request(
channel = await self.make_request(
"POST",
"/_matrix/client/v3/login",
body,
@@ -514,8 +514,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
}
}
)
def test_require_approval(self) -> None:
channel = self.make_request(
async def test_require_approval(self) -> None:
channel = await self.make_request(
"POST",
"register",
{
@@ -535,25 +535,25 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
channel = self.make_request("POST", LOGIN_URL, params)
channel = await self.make_request("POST", LOGIN_URL, params)
self.assertEqual(403, channel.code, channel.result)
self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
self.assertEqual(
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
)
def test_get_login_flows_with_login_via_existing_disabled(self) -> None:
async def test_get_login_flows_with_login_via_existing_disabled(self) -> None:
"""GET /login should return m.login.token without get_login_token"""
channel = self.make_request("GET", "/_matrix/client/r0/login")
channel = await self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
flows = {flow["type"]: flow for flow in channel.json_body["flows"]}
self.assertNotIn("m.login.token", flows)
@override_config({"login_via_existing_session": {"enabled": True}})
def test_get_login_flows_with_login_via_existing_enabled(self) -> None:
async def test_get_login_flows_with_login_via_existing_enabled(self) -> None:
"""GET /login should return m.login.token with get_login_token true"""
channel = self.make_request("GET", "/_matrix/client/r0/login")
channel = await self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
self.assertCountEqual(
@@ -576,13 +576,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
]
}
)
def test_spam_checker_allow(self) -> None:
async def test_spam_checker_allow(self) -> None:
"""Check that that adding a spam checker doesn't break login."""
self.register_user("kermit", "monkey")
await self.register_user("kermit", "monkey")
body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
channel = self.make_request(
channel = await self.make_request(
"POST",
"/_matrix/client/r0/login",
body,
@@ -600,14 +600,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
]
}
)
def test_spam_checker_deny(self) -> None:
async def test_spam_checker_deny(self) -> None:
"""Check that login"""
self.register_user("kermit", "monkey")
await self.register_user("kermit", "monkey")
body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
channel = self.make_request(
channel = await self.make_request(
"POST",
"/_matrix/client/r0/login",
body,
@@ -677,9 +677,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
d.update(build_synapse_client_resource_tree(self.hs))
return d
def test_get_login_flows(self) -> None:
async def test_get_login_flows(self) -> None:
"""GET /login should return password and SSO flows"""
channel = self.make_request("GET", "/_matrix/client/r0/login")
channel = await self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
expected_flow_types = [
@@ -704,17 +704,17 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
],
)
def test_multi_sso_redirect(self) -> None:
async def test_multi_sso_redirect(self) -> None:
"""/login/sso/redirect should redirect to an identity picker"""
# first hit the redirect url, which should redirect to our idp picker
channel = self._make_sso_redirect_request(None)
channel = await self._make_sso_redirect_request(None)
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
uri = location_headers[0]
# hitting that picker should give us some HTML
channel = self.make_request("GET", uri)
channel = await self.make_request("GET", uri)
self.assertEqual(channel.code, 200, channel.result)
# parse the form to check it has fields assumed elsewhere in this class
@@ -734,10 +734,10 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"])
def test_multi_sso_redirect_to_cas(self) -> None:
async def test_multi_sso_redirect_to_cas(self) -> None:
"""If CAS is chosen, should redirect to the CAS server"""
channel = self.make_request(
channel = await self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
@@ -758,7 +758,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
)
# follow the redirect
channel = self.make_request(
channel = await self.make_request(
"GET",
# We have to make this relative to be compatible with `make_request(...)`
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
@@ -787,9 +787,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
service_uri_params = urllib.parse.parse_qs(service_uri_query)
self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)
def test_multi_sso_redirect_to_saml(self) -> None:
async def test_multi_sso_redirect_to_saml(self) -> None:
"""If SAML is chosen, should redirect to the SAML server"""
channel = self.make_request(
channel = await self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
@@ -809,7 +809,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
)
# follow the redirect
channel = self.make_request(
channel = await self.make_request(
"GET",
# We have to make this relative to be compatible with `make_request(...)`
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
@@ -835,14 +835,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
relay_state_param = saml_uri_params["RelayState"][0]
self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)
def test_login_via_oidc(self) -> None:
async def test_login_via_oidc(self) -> None:
"""If OIDC is chosen, should redirect to the OIDC auth endpoint"""
fake_oidc_server = self.helper.fake_oidc_server()
with fake_oidc_server.patch_homeserver(hs=self.hs):
# pick the default OIDC provider
channel = self.make_request(
channel = await self.make_request(
"GET",
f"/_synapse/client/pick_idp?redirectUrl={urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)}&idp=oidc",
)
@@ -861,7 +861,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
with fake_oidc_server.patch_homeserver(hs=self.hs):
# follow the redirect
channel = self.make_request(
channel = await self.make_request(
"GET",
# We have to make this relative to be compatible with `make_request(...)`
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
@@ -897,7 +897,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
TEST_CLIENT_REDIRECT_URL,
)
channel, _ = self.helper.complete_oidc_auth(
channel, _ = await self.helper.complete_oidc_auth(
fake_oidc_server, oidc_uri, cookies, {"sub": "user1"}
)
@@ -925,7 +925,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# finally, submit the matrix login token to the login API, which gives us our
# matrix access token, mxid, and device id.
login_token = params[2][1]
chan = self.make_request(
chan = await self.make_request(
"POST",
"/login",
content={"type": "m.login.token", "token": login_token},
@@ -933,15 +933,15 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@user1:test")
def test_multi_sso_redirect_unknown_idp(self) -> None:
async def test_multi_sso_redirect_unknown_idp(self) -> None:
"""An unknown IdP should cause a 400 bad request error"""
channel = self.make_request(
channel = await self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
)
self.assertEqual(channel.code, 400, channel.result)
def test_multi_sso_redirect_unknown_idp_as_url(self) -> None:
async def test_multi_sso_redirect_unknown_idp_as_url(self) -> None:
"""
An unknown IdP that looks like a URL should cause a 400 bad request error (to
avoid open redirects).
@@ -954,24 +954,24 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
in the URL building tests to cover this case but is only a unit test vs
something at the REST layer here that covers things end-to-end.
"""
channel = self.make_request(
channel = await self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl=something&idp=https://element.io/",
)
self.assertEqual(channel.code, 400, channel.result)
def test_client_idp_redirect_to_unknown(self) -> None:
async def test_client_idp_redirect_to_unknown(self) -> None:
"""If the client tries to pick an unknown IdP, return a 404"""
channel = self._make_sso_redirect_request("xxx")
channel = await self._make_sso_redirect_request("xxx")
self.assertEqual(channel.code, 404, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
def test_client_idp_redirect_to_oidc(self) -> None:
async def test_client_idp_redirect_to_oidc(self) -> None:
"""If the client pick a known IdP, redirect to it"""
fake_oidc_server = self.helper.fake_oidc_server()
with fake_oidc_server.patch_homeserver(hs=self.hs):
channel = self._make_sso_redirect_request("oidc")
channel = await self._make_sso_redirect_request("oidc")
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
@@ -981,7 +981,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint)
def _make_sso_redirect_request(self, idp_prov: str | None = None) -> FakeChannel:
async def _make_sso_redirect_request(self, idp_prov: str | None = None) -> FakeChannel:
"""Send a request to /_matrix/client/r0/login/sso/redirect
... possibly specifying an IDP provider
@@ -991,7 +991,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
endpoint += "/" + idp_prov
endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
return self.make_request(
return await self.make_request(
"GET",
endpoint,
custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)],
@@ -1011,7 +1011,7 @@ class CASTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.base_url = "https://matrix.goodserver.com/"
self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
@@ -1054,7 +1054,7 @@ class CASTestCase(unittest.HomeserverTestCase):
mocked_http_client = Mock(spec=["get_raw"])
mocked_http_client.get_raw.side_effect = get_raw
self.hs = self.setup_test_homeserver(
self.hs = await self.setup_test_homeserver(
config=config,
proxied_http_client=mocked_http_client,
)
@@ -1064,7 +1064,7 @@ class CASTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.deactivate_account_handler = hs.get_deactivate_account_handler()
def test_cas_redirect_confirm(self) -> None:
async def test_cas_redirect_confirm(self) -> None:
"""Tests that the SSO login flow serves a confirmation page before redirecting a
user to the redirect URL.
"""
@@ -1079,7 +1079,7 @@ class CASTestCase(unittest.HomeserverTestCase):
cas_ticket_url = urllib.parse.urlunparse(url_parts)
# Get Synapse to call the fake CAS and serve the template.
channel = self.make_request("GET", cas_ticket_url)
channel = await self.make_request("GET", cas_ticket_url)
# Test that the response is HTML.
self.assertEqual(channel.code, 200, channel.result)
@@ -1105,15 +1105,15 @@ class CASTestCase(unittest.HomeserverTestCase):
}
}
)
def test_cas_redirect_whitelisted(self) -> None:
async def test_cas_redirect_whitelisted(self) -> None:
"""Tests that the SSO login flow serves a redirect to a whitelisted url"""
self._test_redirect("https://legit-site.com/")
await self._test_redirect("https://legit-site.com/")
@override_config({"public_baseurl": "https://example.com"})
def test_cas_redirect_login_fallback(self) -> None:
self._test_redirect("https://example.com/_matrix/static/client/login")
async def test_cas_redirect_login_fallback(self) -> None:
await self._test_redirect("https://example.com/_matrix/static/client/login")
def _test_redirect(self, redirect_url: str) -> None:
async def _test_redirect(self, redirect_url: str) -> None:
"""Tests that the SSO login flow serves a redirect for the given redirect URL."""
cas_ticket_url = (
"/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
@@ -1121,7 +1121,7 @@ class CASTestCase(unittest.HomeserverTestCase):
)
# Get Synapse to call the fake CAS and serve the template.
channel = self.make_request("GET", cas_ticket_url)
channel = await self.make_request("GET", cas_ticket_url)
self.assertEqual(channel.code, 302)
location_headers = channel.headers.getRawHeaders("Location")
@@ -1129,15 +1129,15 @@ class CASTestCase(unittest.HomeserverTestCase):
self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
@override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
def test_deactivated_user(self) -> None:
async def test_deactivated_user(self) -> None:
"""Logging in as a deactivated account should error."""
redirect_url = "https://legit-site.com/"
# First login (to create the user).
self._test_redirect(redirect_url)
await self._test_redirect(redirect_url)
# Deactivate the account.
self.get_success(
await self.get_success(
self.deactivate_account_handler.deactivate_account(
self.user_id, False, create_requester(self.user_id)
)
@@ -1150,7 +1150,7 @@ class CASTestCase(unittest.HomeserverTestCase):
)
# Get Synapse to call the fake CAS and serve the template.
channel = self.make_request("GET", cas_ticket_url)
channel = await self.make_request("GET", cas_ticket_url)
# Because the user is deactivated they are served an error template.
self.assertEqual(channel.code, 403)
@@ -1187,24 +1187,24 @@ class JWTTestCase(unittest.HomeserverTestCase):
result: bytes = jwt.encode(header, payload, secret)
return result.decode("ascii")
def jwt_login(self, *args: Any) -> FakeChannel:
async def jwt_login(self, *args: Any) -> FakeChannel:
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
channel = self.make_request(b"POST", LOGIN_URL, params)
channel = await self.make_request(b"POST", LOGIN_URL, params)
return channel
def test_login_jwt_valid_registered(self) -> None:
self.register_user("kermit", "monkey")
channel = self.jwt_login({"sub": "kermit"})
async def test_login_jwt_valid_registered(self) -> None:
await self.register_user("kermit", "monkey")
channel = await self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
def test_login_jwt_valid_unregistered(self) -> None:
channel = self.jwt_login({"sub": "frog"})
async def test_login_jwt_valid_unregistered(self) -> None:
channel = await self.jwt_login({"sub": "frog"})
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")
def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, "notsecret")
async def test_login_jwt_invalid_signature(self) -> None:
channel = await self.jwt_login({"sub": "frog"}, "notsecret")
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
@@ -1212,8 +1212,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
"JWT validation failed: Signature verification failed",
)
def test_login_jwt_expired(self) -> None:
channel = self.jwt_login({"sub": "frog", "exp": 864000})
async def test_login_jwt_expired(self) -> None:
channel = await self.jwt_login({"sub": "frog", "exp": 864000})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
@@ -1221,9 +1221,9 @@ class JWTTestCase(unittest.HomeserverTestCase):
"JWT validation failed: expired_token: The token is expired",
)
def test_login_jwt_not_before(self) -> None:
async def test_login_jwt_not_before(self) -> None:
now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
channel = await self.jwt_login({"sub": "frog", "nbf": now + 3600})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
@@ -1231,22 +1231,22 @@ class JWTTestCase(unittest.HomeserverTestCase):
"JWT validation failed: invalid_token: The token is not valid yet",
)
def test_login_no_sub(self) -> None:
channel = self.jwt_login({"username": "root"})
async def test_login_no_sub(self) -> None:
channel = await self.jwt_login({"username": "root"})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
@override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}})
def test_login_iss(self) -> None:
async def test_login_iss(self) -> None:
"""Test validating the issuer claim."""
# A valid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
channel = await self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
# An invalid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
channel = await self.jwt_login({"sub": "kermit", "iss": "invalid"})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertRegex(
@@ -1255,7 +1255,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
)
# Not providing an issuer.
channel = self.jwt_login({"sub": "kermit"})
channel = await self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertRegex(
@@ -1263,22 +1263,22 @@ class JWTTestCase(unittest.HomeserverTestCase):
r"^JWT validation failed: missing_claim: Missing [\"']iss[\"'] claim$",
)
def test_login_iss_no_config(self) -> None:
async def test_login_iss_no_config(self) -> None:
"""Test providing an issuer claim without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
channel = await self.jwt_login({"sub": "kermit", "iss": "invalid"})
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
@override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
def test_login_aud(self) -> None:
async def test_login_aud(self) -> None:
"""Test validating the audience claim."""
# A valid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
channel = await self.jwt_login({"sub": "kermit", "aud": "test-audience"})
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
# An invalid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
channel = await self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertRegex(
@@ -1287,7 +1287,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
)
# Not providing an audience.
channel = self.jwt_login({"sub": "kermit"})
channel = await self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertRegex(
@@ -1295,9 +1295,9 @@ class JWTTestCase(unittest.HomeserverTestCase):
r"^JWT validation failed: missing_claim: Missing [\"']aud[\"'] claim$",
)
def test_login_aud_no_config(self) -> None:
async def test_login_aud_no_config(self) -> None:
"""Test providing an audience without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
channel = await self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertRegex(
@@ -1305,36 +1305,36 @@ class JWTTestCase(unittest.HomeserverTestCase):
r"^JWT validation failed: invalid_claim: Invalid claim [\"']aud[\"']$",
)
def test_login_default_sub(self) -> None:
async def test_login_default_sub(self) -> None:
"""Test reading user ID from the default subject claim."""
channel = self.jwt_login({"sub": "kermit"})
channel = await self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
@override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
def test_login_custom_sub(self) -> None:
async def test_login_custom_sub(self) -> None:
"""Test reading user ID from a custom subject claim."""
channel = self.jwt_login({"username": "frog"})
channel = await self.jwt_login({"username": "frog"})
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")
@override_config(
{"jwt_config": {**base_config, "display_name_claim": "display_name"}}
)
def test_login_custom_display_name(self) -> None:
async def test_login_custom_display_name(self) -> None:
"""Test setting a custom display name."""
localpart = "pinkie"
user_id = f"@{localpart}:test"
display_name = "Pinkie Pie"
# Perform the login, specifying a custom display name.
channel = self.jwt_login({"sub": localpart, "display_name": display_name})
channel = await self.jwt_login({"sub": localpart, "display_name": display_name})
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], user_id)
# Fetch the user's display name and check that it was set correctly.
access_token = channel.json_body["access_token"]
channel = self.make_request(
channel = await self.make_request(
"GET",
f"/_matrix/client/v3/profile/{user_id}/displayname",
access_token=access_token,
@@ -1342,23 +1342,23 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["displayname"], display_name)
def test_login_no_token(self) -> None:
async def test_login_no_token(self) -> None:
params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params)
channel = await self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
def test_deactivated_user(self) -> None:
async def test_deactivated_user(self) -> None:
"""Logging in as a deactivated account should error."""
user_id = self.register_user("kermit", "monkey")
self.get_success(
user_id = await self.register_user("kermit", "monkey")
await self.get_success(
self.hs.get_deactivate_account_handler().deactivate_account(
user_id, erase_data=False, requester=create_requester(user_id)
)
)
channel = self.jwt_login({"sub": "kermit"})
channel = await self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_USER_DEACTIVATED")
self.assertEqual(
@@ -1436,18 +1436,18 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
result: bytes = jwt.encode(header, payload, secret)
return result.decode("ascii")
def jwt_login(self, *args: Any) -> FakeChannel:
async def jwt_login(self, *args: Any) -> FakeChannel:
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
channel = self.make_request(b"POST", LOGIN_URL, params)
channel = await self.make_request(b"POST", LOGIN_URL, params)
return channel
def test_login_jwt_valid(self) -> None:
channel = self.jwt_login({"sub": "kermit"})
async def test_login_jwt_valid(self) -> None:
channel = await self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
async def test_login_jwt_invalid_signature(self) -> None:
channel = await self.jwt_login({"sub": "frog"}, self.bad_privatekey)
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
@@ -1465,8 +1465,8 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
register.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = self.setup_test_homeserver()
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = await self.setup_test_homeserver()
self.service = ApplicationService(
id="unique_identifier",
@@ -1511,23 +1511,23 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
self.hs.get_datastores().main.services_cache.append(self.msc4190_service)
return self.hs
def test_login_appservice_user(self) -> None:
async def test_login_appservice_user(self) -> None:
"""Test that an appservice user can use /login"""
self.register_appservice_user(AS_USER, self.service.token)
await self.register_appservice_user(AS_USER, self.service.token)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": AS_USER},
}
channel = self.make_request(
channel = await self.make_request(
b"POST", LOGIN_URL, params, access_token=self.service.token
)
self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_appservice_msc4190_fail(self) -> None:
async def test_login_appservice_msc4190_fail(self) -> None:
"""Test that an appservice user can use /login"""
self.register_appservice_user(
await self.register_appservice_user(
"as3_user_alice", self.msc4190_service.token, inhibit_login=True
)
@@ -1535,7 +1535,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": "as3_user_alice"},
}
channel = self.make_request(
channel = await self.make_request(
b"POST", LOGIN_URL, params, access_token=self.msc4190_service.token
)
@@ -1546,9 +1546,9 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
channel.json_body,
)
def test_login_appservice_user_bot(self) -> None:
async def test_login_appservice_user_bot(self) -> None:
"""Test that the appservice bot can use /login"""
self.register_appservice_user(AS_USER, self.service.token)
await self.register_appservice_user(AS_USER, self.service.token)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
@@ -1557,51 +1557,51 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
"user": self.service.sender.to_string(),
},
}
channel = self.make_request(
channel = await self.make_request(
b"POST", LOGIN_URL, params, access_token=self.service.token
)
self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_appservice_wrong_user(self) -> None:
async def test_login_appservice_wrong_user(self) -> None:
"""Test that non-as users cannot login with the as token"""
self.register_appservice_user(AS_USER, self.service.token)
await self.register_appservice_user(AS_USER, self.service.token)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": "fibble_wibble"},
}
channel = self.make_request(
channel = await self.make_request(
b"POST", LOGIN_URL, params, access_token=self.service.token
)
self.assertEqual(channel.code, 403, msg=channel.result)
def test_login_appservice_wrong_as(self) -> None:
async def test_login_appservice_wrong_as(self) -> None:
"""Test that as users cannot login with wrong as token"""
self.register_appservice_user(AS_USER, self.service.token)
await self.register_appservice_user(AS_USER, self.service.token)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": AS_USER},
}
channel = self.make_request(
channel = await self.make_request(
b"POST", LOGIN_URL, params, access_token=self.another_service.token
)
self.assertEqual(channel.code, 403, msg=channel.result)
def test_login_appservice_no_token(self) -> None:
async def test_login_appservice_no_token(self) -> None:
"""Test that users must provide a token when using the appservice
login method
"""
self.register_appservice_user(AS_USER, self.service.token)
await self.register_appservice_user(AS_USER, self.service.token)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": AS_USER},
}
channel = self.make_request(b"POST", LOGIN_URL, params)
channel = await self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.code, 401, msg=channel.result)
@@ -1616,10 +1616,10 @@ class UsernamePickerTestCase(HomeserverTestCase):
account.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock(spec=["get_file"])
self.http_client.get_file.side_effect = mock_get_file
hs = self.setup_test_homeserver(
hs = await self.setup_test_homeserver(
proxied_blocklisted_http_client=self.http_client
)
return hs
@@ -1648,7 +1648,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
d.update(build_synapse_client_resource_tree(self.hs))
return d
def proceed_to_username_picker_page(
async def proceed_to_username_picker_page(
self,
fake_oidc_server: FakeOidcServer,
displayname: str,
@@ -1656,7 +1656,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
picture: str,
) -> tuple[str, str]:
# do the start of the login flow
channel, _ = self.helper.auth_via_oidc(
channel, _ = await self.helper.auth_via_oidc(
fake_oidc_server,
{
"sub": "tester",
@@ -1701,7 +1701,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
return picker_url, session_id
def test_username_picker_use_displayname_avatar_and_email(self) -> None:
async def test_username_picker_use_displayname_avatar_and_email(self) -> None:
"""Test the happy path of a username picker flow with using displayname, avatar and email."""
fake_oidc_server = self.helper.fake_oidc_server()
@@ -1711,7 +1711,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
email = "bobby@test.com"
picture = "mxc://test/avatar_url"
picker_url, session_id = self.proceed_to_username_picker_page(
picker_url, session_id = await self.proceed_to_username_picker_page(
fake_oidc_server, displayname, email, picture
)
@@ -1726,7 +1726,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
b"use_email": email,
}
).encode("utf8")
chan = self.make_request(
chan = await self.make_request(
"POST",
path=picker_url,
content=content,
@@ -1740,7 +1740,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
assert location_headers
# send a request to the completion page, which should 302 to the client redirectUrl
chan = self.make_request(
chan = await self.make_request(
"GET",
path=location_headers[0],
custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
@@ -1765,7 +1765,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
# finally, submit the matrix login token to the login API, which gives us our
# matrix access token, mxid, and device id.
chan = self.make_request(
chan = await self.make_request(
"POST",
"/login",
content={"type": "m.login.token", "token": login_token},
@@ -1774,7 +1774,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
self.assertEqual(chan.json_body["user_id"], mxid)
# ensure the displayname and avatar from the OIDC response have been configured for the user.
channel = self.make_request(
channel = await self.make_request(
"GET", "/profile/" + mxid, access_token=chan.json_body["access_token"]
)
self.assertEqual(channel.code, 200, channel.result)
@@ -1782,13 +1782,13 @@ class UsernamePickerTestCase(HomeserverTestCase):
self.assertEqual(displayname, channel.json_body["displayname"])
# ensure the email from the OIDC response has been configured for the user.
channel = self.make_request(
channel = await self.make_request(
"GET", "/account/3pid", access_token=chan.json_body["access_token"]
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(email, channel.json_body["threepids"][0]["address"])
def test_username_picker_dont_use_displayname_avatar_or_email(self) -> None:
async def test_username_picker_dont_use_displayname_avatar_or_email(self) -> None:
"""Test the happy path of a username picker flow without using displayname, avatar or email."""
fake_oidc_server = self.helper.fake_oidc_server()
@@ -1799,7 +1799,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
picture = "mxc://test/avatar_url"
username = "bobby"
picker_url, session_id = self.proceed_to_username_picker_page(
picker_url, session_id = await self.proceed_to_username_picker_page(
fake_oidc_server, displayname, email, picture
)
@@ -1813,7 +1813,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
b"use_avatar": b"false",
}
).encode("utf8")
chan = self.make_request(
chan = await self.make_request(
"POST",
path=picker_url,
content=content,
@@ -1827,7 +1827,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
assert location_headers
# send a request to the completion page, which should 302 to the client redirectUrl
chan = self.make_request(
chan = await self.make_request(
"GET",
path=location_headers[0],
custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
@@ -1852,7 +1852,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
# finally, submit the matrix login token to the login API, which gives us our
# matrix access token, mxid, and device id.
chan = self.make_request(
chan = await self.make_request(
"POST",
"/login",
content={"type": "m.login.token", "token": login_token},
@@ -1861,7 +1861,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
self.assertEqual(chan.json_body["user_id"], mxid)
# ensure the displayname and avatar from the OIDC response have not been configured for the user.
channel = self.make_request(
channel = await self.make_request(
"GET", "/profile/" + mxid, access_token=chan.json_body["access_token"]
)
self.assertEqual(channel.code, 200, channel.result)
@@ -1869,7 +1869,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
self.assertEqual(username, channel.json_body["displayname"])
# ensure the email from the OIDC response has not been configured for the user.
channel = self.make_request(
channel = await self.make_request(
"GET", "/account/3pid", access_token=chan.json_body["access_token"]
)
self.assertEqual(channel.code, 200, channel.result)

View File

@@ -81,7 +81,7 @@ class RestHelper:
auth_user_id: str | None
@overload
def create_room_as(
async def create_room_as(
self,
room_creator: str | None = ...,
is_public: bool | None = ...,
@@ -93,7 +93,7 @@ class RestHelper:
) -> str: ...
@overload
def create_room_as(
async def create_room_as(
self,
room_creator: str | None = ...,
is_public: bool | None = ...,
@@ -104,7 +104,7 @@ class RestHelper:
custom_headers: Iterable[tuple[AnyStr, AnyStr]] | None = ...,
) -> str | None: ...
def create_room_as(
async def create_room_as(
self,
room_creator: str | None = None,
is_public: bool | None = True,
@@ -149,7 +149,7 @@ class RestHelper:
if tok:
path = path + "?access_token=%s" % tok
channel = make_request(
channel = await make_request(
self.reactor,
self.site,
"POST",
@@ -166,7 +166,7 @@ class RestHelper:
else:
return None
def invite(
async def invite(
self,
room: str,
src: str | None = None,
@@ -175,7 +175,7 @@ class RestHelper:
tok: str | None = None,
extra_data: dict | None = None,
) -> JsonDict:
return self.change_membership(
return await self.change_membership(
room=room,
src=src,
targ=targ,
@@ -185,7 +185,7 @@ class RestHelper:
extra_data=extra_data,
)
def join(
async def join(
self,
room: str,
user: str,
@@ -195,7 +195,7 @@ class RestHelper:
expect_errcode: Codes | None = None,
expect_additional_fields: dict | None = None,
) -> JsonDict:
return self.change_membership(
return await self.change_membership(
room=room,
src=user,
targ=user,
@@ -207,7 +207,7 @@ class RestHelper:
expect_additional_fields=expect_additional_fields,
)
def knock(
async def knock(
self,
room: str | None = None,
user: str | None = None,
@@ -225,7 +225,7 @@ class RestHelper:
if reason:
data["reason"] = reason
channel = make_request(
channel = await make_request(
self.reactor,
self.site,
"POST",
@@ -241,14 +241,14 @@ class RestHelper:
self.auth_user_id = temp_id
def leave(
async def leave(
self,
room: str,
user: str | None = None,
expect_code: int = HTTPStatus.OK,
tok: str | None = None,
) -> JsonDict:
return self.change_membership(
return await self.change_membership(
room=room,
src=user,
targ=user,
@@ -257,7 +257,7 @@ class RestHelper:
expect_code=expect_code,
)
def ban(
async def ban(
self,
room: str,
src: str,
@@ -266,7 +266,7 @@ class RestHelper:
tok: str | None = None,
) -> JsonDict:
"""A convenience helper: `change_membership` with `membership` preset to "ban"."""
return self.change_membership(
return await self.change_membership(
room=room,
src=src,
targ=targ,
@@ -275,7 +275,7 @@ class RestHelper:
expect_code=expect_code,
)
def change_membership(
async def change_membership(
self,
room: str,
src: str | None,
@@ -325,7 +325,7 @@ class RestHelper:
data = {"membership": membership}
data.update(extra_data or {})
channel = make_request(
channel = await make_request(
self.reactor,
self.site,
"PUT",
@@ -372,7 +372,7 @@ class RestHelper:
self.auth_user_id = temp_id
return channel.json_body
def send(
async def send(
self,
room_id: str,
body: str | None = None,
@@ -387,7 +387,7 @@ class RestHelper:
content = {"msgtype": "m.text", "body": body}
return self.send_event(
return await self.send_event(
room_id,
type,
content,
@@ -397,7 +397,7 @@ class RestHelper:
custom_headers=custom_headers,
)
def send_messages(
async def send_messages(
self,
room_id: str,
num_events: int,
@@ -413,7 +413,7 @@ class RestHelper:
event_ids = []
for event_index in range(num_events):
response = self.send_event(
response = await self.send_event(
room_id,
EventTypes.Message,
content_fn(event_index),
@@ -423,7 +423,7 @@ class RestHelper:
return event_ids
def send_event(
async def send_event(
self,
room_id: str,
type: str,
@@ -440,7 +440,7 @@ class RestHelper:
if tok:
path = path + "?access_token=%s" % tok
channel = make_request(
channel = await make_request(
self.reactor,
self.site,
"PUT",
@@ -457,7 +457,7 @@ class RestHelper:
return channel.json_body
def send_sticky_event(
async def send_sticky_event(
self,
room_id: str,
type: str,
@@ -480,7 +480,7 @@ class RestHelper:
if tok:
path = path + f"&access_token={tok}"
channel = make_request(
channel = await make_request(
self.reactor,
self.site,
"PUT",
@@ -495,7 +495,7 @@ class RestHelper:
return channel.json_body
def get_event(
async def get_event(
self,
room_id: str,
event_id: str,
@@ -517,7 +517,7 @@ class RestHelper:
if tok:
path = path + f"?access_token={tok}"
channel = make_request(
channel = await make_request(
self.reactor,
self.site,
"GET",
@@ -532,7 +532,7 @@ class RestHelper:
return channel.json_body
def _read_write_state(
async def _read_write_state(
self,
room_id: str,
event_type: str,
@@ -573,7 +573,7 @@ class RestHelper:
if body is not None:
content = json.dumps(body).encode("utf8")
channel = make_request(self.reactor, self.site, method, path, content)
channel = await make_request(self.reactor, self.site, method, path, content)
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
@@ -583,7 +583,7 @@ class RestHelper:
return channel.json_body
def get_state(
async def get_state(
self,
room_id: str,
event_type: str,
@@ -606,11 +606,11 @@ class RestHelper:
Raises:
AssertionError: if expect_code doesn't match the HTTP code we received
"""
return self._read_write_state(
return await self._read_write_state(
room_id, event_type, None, tok, expect_code, state_key, method="GET"
)
def send_state(
async def send_state(
self,
room_id: str,
event_type: str,
@@ -635,11 +635,11 @@ class RestHelper:
Raises:
AssertionError: if expect_code doesn't match the HTTP code we received
"""
return self._read_write_state(
return await self._read_write_state(
room_id, event_type, body, tok, expect_code, state_key, method="PUT"
)
def upload_media(
async def upload_media(
self,
image_data: bytes,
tok: str,
@@ -655,7 +655,7 @@ class RestHelper:
expect_code: The return code to expect from attempting to upload the media
"""
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
channel = make_request(
channel = await make_request(
self.reactor,
self.site,
"POST",
@@ -672,7 +672,7 @@ class RestHelper:
return channel.json_body
def whoami(
async def whoami(
self,
access_token: str,
expect_code: Literal[HTTPStatus.OK, HTTPStatus.UNAUTHORIZED] = HTTPStatus.OK,
@@ -684,7 +684,7 @@ class RestHelper:
access_token: The user token to use during the request
expect_code: The return code to expect from attempting the whoami request
"""
channel = make_request(
channel = await make_request(
self.reactor,
self.site,
"GET",
@@ -714,7 +714,7 @@ class RestHelper:
issuer=issuer,
)
def login_via_oidc(
async def login_via_oidc(
self,
fake_server: FakeOidcServer,
remote_user_id: str,
@@ -735,7 +735,7 @@ class RestHelper:
"""
client_redirect_url = "https://x"
userinfo = {"sub": remote_user_id}
channel, grant = self.auth_via_oidc(
channel, grant = await self.auth_via_oidc(
fake_server,
userinfo,
client_redirect_url,
@@ -754,9 +754,9 @@ class RestHelper:
assert m, channel.text_body
login_token = m.group(1)
return self.login_via_token(login_token, expected_status), grant
return await self.login_via_token(login_token, expected_status), grant
def login_via_token(
async def login_via_token(
self,
login_token: str,
expected_status: int = 200,
@@ -774,7 +774,7 @@ class RestHelper:
the normal places.
"""
channel = make_request(
channel = await make_request(
self.reactor,
self.site,
"POST",
@@ -786,7 +786,7 @@ class RestHelper:
)
return channel.json_body
def auth_via_oidc(
async def auth_via_oidc(
self,
fake_server: FakeOidcServer,
user_info_dict: JsonDict,
@@ -834,10 +834,10 @@ class RestHelper:
if ui_auth_session_id:
# can't set the client redirect url for UI Auth
assert client_redirect_url is None
oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
oauth_uri = await self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
else:
# otherwise, hit the login redirect endpoint
oauth_uri = self.initiate_sso_login(
oauth_uri = await self.initiate_sso_login(
client_redirect_url, cookies, idp_id=idp_id
)
@@ -850,11 +850,11 @@ class RestHelper:
assert oauth_uri_path == fake_server.authorization_endpoint, (
"unexpected SSO URI " + oauth_uri_path
)
return self.complete_oidc_auth(
return await self.complete_oidc_auth(
fake_server, oauth_uri, cookies, user_info_dict, with_sid=with_sid
)
def complete_oidc_auth(
async def complete_oidc_auth(
self,
fake_serer: FakeOidcServer,
oauth_uri: str,
@@ -904,7 +904,7 @@ class RestHelper:
with fake_serer.patch_homeserver(hs=self.hs):
# now hit the callback URI with the right params and a made-up code
channel = make_request(
channel = await make_request(
self.reactor,
self.site,
"GET",
@@ -915,7 +915,7 @@ class RestHelper:
)
return channel, grant
def initiate_sso_login(
async def initiate_sso_login(
self,
client_redirect_url: str | None,
cookies: MutableMapping[str, str],
@@ -948,7 +948,7 @@ class RestHelper:
# hit the redirect url (which should redirect back to the redirect url. This
# is the easiest way of figuring out what the Host header ought to be set to
# to keep Synapse happy.
channel = make_request(
channel = await make_request(
self.reactor,
self.site,
"GET",
@@ -967,7 +967,7 @@ class RestHelper:
location = get_location(channel)
parts = urllib.parse.urlsplit(location)
next_uri = urllib.parse.urlunsplit(("", "") + parts[2:])
channel = make_request(
channel = await make_request(
self.reactor,
self.site,
"GET",
@@ -981,7 +981,7 @@ class RestHelper:
channel.extract_cookies(cookies)
return get_location(channel)
def initiate_sso_ui_auth(
async def initiate_sso_ui_auth(
self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
) -> str:
"""Make a request to the ui-auth-via-sso endpoint, and return the target
@@ -1001,7 +1001,7 @@ class RestHelper:
+ urllib.parse.urlencode({"session": ui_auth_session_id})
)
# hit the redirect url (which will issue a cookie and state)
channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint)
channel = await make_request(self.reactor, self.site, "GET", sso_redirect_endpoint)
# that should serve a confirmation page
assert channel.code == HTTPStatus.OK, channel.text_body
channel.extract_cookies(cookies)
@@ -1014,9 +1014,9 @@ class RestHelper:
oauth_uri = p.links[0]
return oauth_uri
def send_read_receipt(self, room_id: str, event_id: str, *, tok: str) -> None:
async def send_read_receipt(self, room_id: str, event_id: str, *, tok: str) -> None:
"""Send a read receipt into the room at the given event"""
channel = make_request(
channel = await make_request(
self.reactor,
self.site,
method="POST",

View File

@@ -538,10 +538,26 @@ async def make_request(
req.dispatch(root_resource)
if await_result and req.render_deferred is not None:
# Await the handler task directly — the event loop drives
# all concurrent tasks (DB operations, background processes)
# naturally without needing nest_asyncio.
await req.render_deferred
import asyncio
# Advance fake time in a background task so that any
# clock.sleep() calls in the handler (e.g., ratelimit pauses)
# get resolved. We advance by 0.1s per tick.
async def _advance_time() -> None:
while not req.render_deferred.done():
if clock is not None:
clock.advance(0.1)
await asyncio.sleep(0.01)
advancer = asyncio.ensure_future(_advance_time())
try:
await req.render_deferred
finally:
advancer.cancel()
try:
await advancer
except asyncio.CancelledError:
pass
return channel