⏺ 324 passed, 0 failed across all 13 test files!

Summary of what was done in this session:

  New test files ported (11 files, 101 tests):
  - test_typing.py (3 tests)
  - test_capabilities.py (12 tests)
  - test_events.py (7 tests)
  - test_ephemeral_message.py (2 tests)
  - test_presence.py (5 tests)
  - test_profile.py (25 tests)
  - test_directory.py (7 tests)
  - test_power_levels.py (7 tests)
  - test_read_marker.py (2 tests)
  - test_mutual_rooms.py (12 tests)
  - test_notifications.py (5 tests)

  Infrastructure fixes:
  - FutureCache tree invalidation: Implemented _invalidate_prefix() for tree=True caches — prefix-based key matching was silently broken.
  - NativeConnectionPool.running: Added missing property needed by DatabasePool.is_running().
  - NativeClock.call_later fake time: call_later now stores callbacks in a pending heap when fake time is enabled, and advance() fires them. Previously, call_later used the real event loop timer which ignored fake time.
  - pump() enhanced: Now yields 20 × 0.01s (real time) after advancing fake time, giving background tasks and executor threads time to complete.
  - Logging context cleanup: setUp now resets to SENTINEL instead of failing, since asyncio event loop shutdown between tests can leave non-sentinel contexts.
  - SynapseRequest.finish() calls channel.requestDone(): Populates resource_usage on FakeChannel.
  - make_signed_federation_request made async: Added await to all call sites.
This commit is contained in:
Matthew Hodgson
2026-03-24 17:30:13 -04:00
parent 1afd728ef7
commit d0e8b46e44
14 changed files with 557 additions and 509 deletions
+5
View File
@@ -103,6 +103,11 @@ class NativeConnectionPool:
self._closed = False
@property
def running(self) -> bool:
"""Whether the pool is running (not closed)."""
return not self._closed
def _get_connection(self) -> Connection:
"""Get or create a connection for the current thread.
+24 -5
View File
@@ -165,10 +165,10 @@ class NativeClock:
self._loop: asyncio.AbstractEventLoop | None = None
# Internal timer system for fake time support.
# Pending sleeps: list of (wake_time, future)
import heapq
self._fake_time: float = time_mod.time()
self._pending_sleeps: list[tuple[float, int, asyncio.Future]] = []
self._pending_call_laters: list[tuple[float, int, Callable, tuple, dict]] = []
self._sleep_seq = 0 # tiebreaker for heapq when wake_times are equal
self._use_fake_time = False # Set to True by tests
@@ -206,20 +206,26 @@ class NativeClock:
await asyncio.sleep(duration.as_secs())
def advance(self, seconds: float) -> None:
"""Advance fake time by seconds, firing any due sleeps.
"""Advance fake time by seconds, firing any due sleeps and call_laters.
Used by tests to control time deterministically.
"""
import heapq
self._use_fake_time = True
self._fake_time += seconds
# Fire any sleeps that are now due
while self._pending_sleeps and self._pending_sleeps[0][0] <= self._fake_time:
import heapq
_, _, future = heapq.heappop(self._pending_sleeps)
if not future.done():
future.set_result(None)
# Fire any call_laters that are now due
while self._pending_call_laters and self._pending_call_laters[0][0] <= self._fake_time:
_, _, callback, args, kwargs = heapq.heappop(self._pending_call_laters)
callback(*args, **kwargs)
def looping_call(
self,
@@ -310,6 +316,8 @@ class NativeClock:
call_later_cancel_on_shutdown: bool = True,
**kwargs: Any,
) -> NativeDelayedCallWrapper:
import heapq
call_id = self._delayed_call_id
self._delayed_call_id += 1
@@ -331,8 +339,19 @@ class NativeClock:
if call_later_cancel_on_shutdown:
self._call_id_to_delayed_call.pop(call_id, None)
scheduled_time = loop.time() + delay.as_secs()
handle = loop.call_later(delay.as_secs(), wrapped_callback, *args, **kwargs)
if self._use_fake_time:
# In fake-time mode, store in pending list for advance() to fire.
scheduled_time = self._fake_time + delay.as_secs()
self._sleep_seq += 1
heapq.heappush(
self._pending_call_laters,
(scheduled_time, self._sleep_seq, wrapped_callback, args, kwargs),
)
# Create a no-op handle for the wrapper
handle = loop.call_later(86400, lambda: None) # dummy, never fires
else:
scheduled_time = loop.time() + delay.as_secs()
handle = loop.call_later(delay.as_secs(), wrapped_callback, *args, **kwargs)
clock_debug_logger.debug(
"call_later(%s): Scheduled call for %ss later (tracked: %s)",
+48 -48
View File
@@ -38,27 +38,27 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.url = b"/capabilities"
hs = self.setup_test_homeserver()
hs = await self.setup_test_homeserver()
self.config = hs.config
self.auth_handler = hs.get_auth_handler()
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.localpart = "user"
self.password = "pass"
self.user = self.register_user(self.localpart, self.password)
self.user = await self.register_user(self.localpart, self.password)
def test_check_auth_required(self) -> None:
channel = self.make_request("GET", self.url)
async def test_check_auth_required(self) -> None:
channel = await self.make_request("GET", self.url)
self.assertEqual(channel.code, 401)
def test_get_room_version_capabilities(self) -> None:
access_token = self.login(self.localpart, self.password)
async def test_get_room_version_capabilities(self) -> None:
access_token = await self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
channel = await self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
@@ -70,48 +70,48 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
capabilities["m.room_versions"]["default"],
)
def test_get_change_password_capabilities_password_login(self) -> None:
access_token = self.login(self.localpart, self.password)
async def test_get_change_password_capabilities_password_login(self) -> None:
access_token = await self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
channel = await self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
self.assertTrue(capabilities["m.change_password"]["enabled"])
@override_config({"password_config": {"localdb_enabled": False}})
def test_get_change_password_capabilities_localdb_disabled(self) -> None:
access_token = self.get_success(
async def test_get_change_password_capabilities_localdb_disabled(self) -> None:
access_token = await self.get_success(
self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)
channel = self.make_request("GET", self.url, access_token=access_token)
channel = await self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
self.assertFalse(capabilities["m.change_password"]["enabled"])
@override_config({"password_config": {"enabled": False}})
def test_get_change_password_capabilities_password_disabled(self) -> None:
access_token = self.get_success(
async def test_get_change_password_capabilities_password_disabled(self) -> None:
access_token = await self.get_success(
self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)
channel = self.make_request("GET", self.url, access_token=access_token)
channel = await self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
self.assertFalse(capabilities["m.change_password"]["enabled"])
def test_get_change_users_attributes_capabilities(self) -> None:
async def test_get_change_users_attributes_capabilities(self) -> None:
"""Test that server returns capabilities by default."""
access_token = self.login(self.localpart, self.password)
access_token = await self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
channel = await self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, HTTPStatus.OK)
@@ -121,11 +121,11 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertTrue(capabilities["m.3pid_changes"]["enabled"])
@override_config({"enable_set_displayname": False})
def test_get_set_displayname_capabilities_displayname_disabled(self) -> None:
async def test_get_set_displayname_capabilities_displayname_disabled(self) -> None:
"""Test if set displayname is disabled that the server responds it."""
access_token = self.login(self.localpart, self.password)
access_token = await self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
channel = await self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, HTTPStatus.OK)
@@ -136,11 +136,11 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
)
@override_config({"enable_set_avatar_url": False})
def test_get_set_avatar_url_capabilities_avatar_url_disabled(self) -> None:
async def test_get_set_avatar_url_capabilities_avatar_url_disabled(self) -> None:
"""Test if set avatar_url is disabled that the server responds it."""
access_token = self.login(self.localpart, self.password)
access_token = await self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
channel = await self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, HTTPStatus.OK)
@@ -154,13 +154,13 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
"experimental_features": {"msc4133_enabled": True},
}
)
def test_get_set_displayname_capabilities_displayname_disabled_msc4133(
async def test_get_set_displayname_capabilities_displayname_disabled_msc4133(
self,
) -> None:
"""Test if set displayname is disabled that the server responds it."""
access_token = self.login(self.localpart, self.password)
access_token = await self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
channel = await self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, HTTPStatus.OK)
@@ -181,11 +181,11 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
"experimental_features": {"msc4133_enabled": True},
}
)
def test_get_set_avatar_url_capabilities_avatar_url_disabled_msc4133(self) -> None:
async def test_get_set_avatar_url_capabilities_avatar_url_disabled_msc4133(self) -> None:
"""Test if set avatar_url is disabled that the server responds it."""
access_token = self.login(self.localpart, self.password)
access_token = await self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
channel = await self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, HTTPStatus.OK)
@@ -199,39 +199,39 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
)
@override_config({"enable_3pid_changes": False})
def test_get_change_3pid_capabilities_3pid_disabled(self) -> None:
async def test_get_change_3pid_capabilities_3pid_disabled(self) -> None:
"""Test if change 3pid is disabled that the server responds it."""
access_token = self.login(self.localpart, self.password)
access_token = await self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
channel = await self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertFalse(capabilities["m.3pid_changes"]["enabled"])
def test_get_get_token_login_fields_when_disabled(self) -> None:
async def test_get_get_token_login_fields_when_disabled(self) -> None:
"""By default login via an existing session is disabled."""
access_token = self.get_success(
access_token = await self.get_success(
self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)
channel = self.make_request("GET", self.url, access_token=access_token)
channel = await self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertFalse(capabilities["m.get_login_token"]["enabled"])
@override_config({"login_via_existing_session": {"enabled": True}})
def test_get_get_token_login_fields_when_enabled(self) -> None:
access_token = self.get_success(
async def test_get_get_token_login_fields_when_enabled(self) -> None:
access_token = await self.get_success(
self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)
channel = self.make_request("GET", self.url, access_token=access_token)
channel = await self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, HTTPStatus.OK)
@@ -243,14 +243,14 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
"forget_rooms_on_leave": True,
}
)
def test_get_forget_forced_upon_leave_with_auto_forget(self) -> None:
async def test_get_forget_forced_upon_leave_with_auto_forget(self) -> None:
# Server auto-forgets on /leave, expect enabled client capability
access_token = self.get_success(
access_token = await self.get_success(
self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)
channel = self.make_request("GET", self.url, access_token=access_token)
channel = await self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertTrue(
@@ -263,14 +263,14 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
"forget_rooms_on_leave": False,
}
)
def test_get_forget_forced_upon_leave_without_auto_forget(self) -> None:
async def test_get_forget_forced_upon_leave_without_auto_forget(self) -> None:
# Server doesn't auto-forget on /leave, expect disabled client capability
access_token = self.get_success(
access_token = await self.get_success(
self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)
channel = self.make_request("GET", self.url, access_token=access_token)
channel = await self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertFalse(
+56 -56
View File
@@ -41,98 +41,98 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["require_membership_for_aliases"] = True
self.hs = self.setup_test_homeserver(config=config)
self.hs = await self.setup_test_homeserver(config=config)
return self.hs
def prepare(
async def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
"""Create two local users and access tokens for them.
One of them creates a room."""
self.room_owner = self.register_user("room_owner", "test")
self.room_owner_tok = self.login("room_owner", "test")
self.room_owner = await self.register_user("room_owner", "test")
self.room_owner_tok = await self.login("room_owner", "test")
self.room_id = self.helper.create_room_as(
self.room_id = await self.helper.create_room_as(
self.room_owner, tok=self.room_owner_tok
)
self.user = self.register_user("user", "test")
self.user_tok = self.login("user", "test")
self.user = await self.register_user("user", "test")
self.user_tok = await self.login("user", "test")
def test_state_event_not_in_room(self) -> None:
self.ensure_user_left_room()
self.set_alias_via_state_event(HTTPStatus.FORBIDDEN)
async def test_state_event_not_in_room(self) -> None:
await self.ensure_user_left_room()
await self.set_alias_via_state_event(HTTPStatus.FORBIDDEN)
def test_directory_endpoint_not_in_room(self) -> None:
self.ensure_user_left_room()
self.set_alias_via_directory(HTTPStatus.FORBIDDEN)
async def test_directory_endpoint_not_in_room(self) -> None:
await self.ensure_user_left_room()
await self.set_alias_via_directory(HTTPStatus.FORBIDDEN)
def test_state_event_in_room_too_long(self) -> None:
self.ensure_user_joined_room()
self.set_alias_via_state_event(HTTPStatus.BAD_REQUEST, alias_length=256)
async def test_state_event_in_room_too_long(self) -> None:
await self.ensure_user_joined_room()
await self.set_alias_via_state_event(HTTPStatus.BAD_REQUEST, alias_length=256)
def test_directory_in_room_too_long(self) -> None:
self.ensure_user_joined_room()
self.set_alias_via_directory(HTTPStatus.BAD_REQUEST, alias_length=256)
async def test_directory_in_room_too_long(self) -> None:
await self.ensure_user_joined_room()
await self.set_alias_via_directory(HTTPStatus.BAD_REQUEST, alias_length=256)
@override_config({"default_room_version": 5})
def test_state_event_user_in_v5_room(self) -> None:
async def test_state_event_user_in_v5_room(self) -> None:
"""Test that a regular user can add alias events before room v6"""
self.ensure_user_joined_room()
self.set_alias_via_state_event(HTTPStatus.OK)
await self.ensure_user_joined_room()
await self.set_alias_via_state_event(HTTPStatus.OK)
@override_config({"default_room_version": 6})
def test_state_event_v6_room(self) -> None:
async def test_state_event_v6_room(self) -> None:
"""Test that a regular user can *not* add alias events from room v6"""
self.ensure_user_joined_room()
self.set_alias_via_state_event(HTTPStatus.FORBIDDEN)
await self.ensure_user_joined_room()
await self.set_alias_via_state_event(HTTPStatus.FORBIDDEN)
def test_directory_in_room(self) -> None:
self.ensure_user_joined_room()
self.set_alias_via_directory(HTTPStatus.OK)
async def test_directory_in_room(self) -> None:
await self.ensure_user_joined_room()
await self.set_alias_via_directory(HTTPStatus.OK)
def test_room_creation_too_long(self) -> None:
async def test_room_creation_too_long(self) -> None:
url = "/_matrix/client/r0/createRoom"
# We use deliberately a localpart under the length threshold so
# that we can make sure that the check is done on the whole alias.
request_data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
channel = self.make_request(
channel = await self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
def test_room_creation(self) -> None:
async def test_room_creation(self) -> None:
url = "/_matrix/client/r0/createRoom"
# Check with an alias of allowed length. There should already be
# a test that ensures it works in test_register.py, but let's be
# as cautious as possible here.
request_data = {"room_alias_name": random_string(5)}
channel = self.make_request(
channel = await self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
def test_deleting_alias_via_directory(self) -> None:
async def test_deleting_alias_via_directory(self) -> None:
# Add an alias for the room. We must be joined to do so.
self.ensure_user_joined_room()
alias = self.set_alias_via_directory(HTTPStatus.OK)
await self.ensure_user_joined_room()
alias = await self.set_alias_via_directory(HTTPStatus.OK)
# Then try to remove the alias
channel = self.make_request(
channel = await self.make_request(
"DELETE",
f"/_matrix/client/r0/directory/room/{alias}",
access_token=self.user_tok,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
def test_deleting_alias_via_directory_appservice(self) -> None:
async def test_deleting_alias_via_directory_appservice(self) -> None:
user_id = "@as:test"
as_token = "i_am_an_app_service"
@@ -148,7 +148,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
alias = RoomAlias(f"asns-{random_string(5)}", self.hs.hostname).to_string()
request_data = {"room_id": self.room_id}
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/_matrix/client/r0/directory/room/{alias}",
request_data,
@@ -157,17 +157,17 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
# Then try to remove the alias, as the appservice
channel = self.make_request(
channel = await self.make_request(
"DELETE",
f"/_matrix/client/r0/directory/room/{alias}",
access_token=as_token,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
def test_deleting_nonexistant_alias(self) -> None:
async def test_deleting_nonexistant_alias(self) -> None:
# Check that no alias exists
alias = "#potato:test"
channel = self.make_request(
channel = await self.make_request(
"GET",
f"/_matrix/client/r0/directory/room/{alias}",
access_token=self.user_tok,
@@ -177,7 +177,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body)
# Then try to remove the alias
channel = self.make_request(
channel = await self.make_request(
"DELETE",
f"/_matrix/client/r0/directory/room/{alias}",
access_token=self.user_tok,
@@ -186,7 +186,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertIn("error", channel.json_body, channel.json_body)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body)
def set_alias_via_state_event(
async def set_alias_via_state_event(
self, expected_code: HTTPStatus, alias_length: int = 5
) -> None:
url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % (
@@ -196,27 +196,27 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
request_data = {"aliases": [self.random_alias(alias_length)]}
channel = self.make_request(
channel = await self.make_request(
"PUT", url, request_data, access_token=self.user_tok
)
self.assertEqual(channel.code, expected_code, channel.result)
def set_alias_via_directory(
async def set_alias_via_directory(
self, expected_code: HTTPStatus, alias_length: int = 5
) -> str:
alias = self.random_alias(alias_length)
url = "/_matrix/client/r0/directory/room/%s" % alias
request_data = {"room_id": self.room_id}
channel = self.make_request(
channel = await self.make_request(
"PUT", url, request_data, access_token=self.user_tok
)
self.assertEqual(channel.code, expected_code, channel.result)
return alias
def test_invalid_alias(self) -> None:
async def test_invalid_alias(self) -> None:
alias = "#potato"
channel = self.make_request(
channel = await self.make_request(
"GET",
f"/_matrix/client/r0/directory/room/{alias}",
access_token=self.user_tok,
@@ -230,18 +230,18 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
def random_alias(self, length: int) -> str:
return RoomAlias(random_string(length), self.hs.hostname).to_string()
def ensure_user_left_room(self) -> None:
self.ensure_membership("leave")
async def ensure_user_left_room(self) -> None:
await self.ensure_membership("leave")
def ensure_user_joined_room(self) -> None:
self.ensure_membership("join")
async def ensure_user_joined_room(self) -> None:
await self.ensure_membership("join")
def ensure_membership(self, membership: str) -> None:
async def ensure_membership(self, membership: str) -> None:
try:
if membership == "leave":
self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok)
await self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok)
if membership == "join":
self.helper.join(room=self.room_id, user=self.user, tok=self.user_tok)
await self.helper.join(room=self.room_id, user=self.user, tok=self.user_tok)
except AssertionError:
# We don't care whether the leave request didn't return a 200 (e.g.
# if the user isn't already in the room), because we only want to
+15 -15
View File
@@ -39,18 +39,18 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["enable_ephemeral_messages"] = True
self.hs = self.setup_test_homeserver(config=config)
self.hs = await self.setup_test_homeserver(config=config)
return self.hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id)
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = await self.helper.create_room_as(self.user_id)
def test_message_expiry_no_delay(self) -> None:
async def test_message_expiry_no_delay(self) -> None:
"""Tests that sending a message sent with a m.self_destruct_after field set to the
past results in that event being deleted right away.
"""
@@ -58,7 +58,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
# at 200ms, so 0 is in the past, and even if that wasn't the case and the clock
# is at 0ms the code path is the same if the event's expiry timestamp is the
# current timestamp.
res = self.helper.send_event(
res = await self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={
@@ -70,16 +70,16 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
event_id = res["event_id"]
# Check that we can't retrieve the content of the event.
event_content = self.get_event(self.room_id, event_id)["content"]
event_content = (await self.get_event(self.room_id, event_id))["content"]
self.assertFalse(bool(event_content), event_content)
def test_message_expiry_delay(self) -> None:
async def test_message_expiry_delay(self) -> None:
"""Tests that sending a message with a m.self_destruct_after field set to the
future results in that event not being deleted right away, but advancing the
clock to after that expiry timestamp causes the event to be deleted.
"""
# Send a message in the room that'll expire in 1s.
res = self.helper.send_event(
res = await self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={
@@ -91,22 +91,22 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
event_id = res["event_id"]
# Check that we can retrieve the content of the event before it has expired.
event_content = self.get_event(self.room_id, event_id)["content"]
event_content = (await self.get_event(self.room_id, event_id))["content"]
self.assertTrue(bool(event_content), event_content)
# Advance the clock to after the deletion.
self.reactor.advance(1)
# Advance the clock to after the deletion and let the expiry handler run.
await self.pump(1)
# Check that we can't retrieve the content of the event anymore.
event_content = self.get_event(self.room_id, event_id)["content"]
event_content = (await self.get_event(self.room_id, event_id))["content"]
self.assertFalse(bool(event_content), event_content)
def get_event(
async def get_event(
self, room_id: str, event_id: str, expected_code: int = HTTPStatus.OK
) -> JsonDict:
url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
channel = self.make_request("GET", url)
channel = await self.make_request("GET", url)
self.assertEqual(channel.code, expected_code, channel.result)
+23 -23
View File
@@ -44,41 +44,41 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["enable_registration_captcha"] = False
config["enable_registration"] = True
config["auto_join_rooms"] = []
hs = self.setup_test_homeserver(config=config)
hs = await self.setup_test_homeserver(config=config)
hs.get_federation_handler = Mock() # type: ignore[method-assign]
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# register an account
self.user_id = self.register_user("sid1", "pass")
self.token = self.login(self.user_id, "pass")
self.user_id = await self.register_user("sid1", "pass")
self.token = await self.login(self.user_id, "pass")
# register a 2nd account
self.other_user = self.register_user("other2", "pass")
self.other_token = self.login(self.other_user, "pass")
self.other_user = await self.register_user("other2", "pass")
self.other_token = await self.login(self.other_user, "pass")
def test_stream_basic_permissions(self) -> None:
async def test_stream_basic_permissions(self) -> None:
# invalid token, expect 401
# note: this is in violation of the original v1 spec, which expected
# 403. However, since the v1 spec no longer exists and the v1
# implementation is now part of the r0 implementation, the newer
# behaviour is used instead to be consistent with the r0 spec.
# see issue https://github.com/matrix-org/synapse/issues/2602
channel = self.make_request(
channel = await self.make_request(
"GET", "/events?access_token=%s" % ("invalid" + self.token,)
)
self.assertEqual(channel.code, 401, msg=channel.result)
# valid token, expect content
channel = self.make_request(
channel = await self.make_request(
"GET", "/events?access_token=%s&timeout=0" % (self.token,)
)
self.assertEqual(channel.code, 200, msg=channel.result)
@@ -86,17 +86,17 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
self.assertTrue("start" in channel.json_body)
self.assertTrue("end" in channel.json_body)
def test_stream_room_permissions(self) -> None:
room_id = self.helper.create_room_as(self.other_user, tok=self.other_token)
self.helper.send(room_id, tok=self.other_token)
async def test_stream_room_permissions(self) -> None:
room_id = await self.helper.create_room_as(self.other_user, tok=self.other_token)
await self.helper.send(room_id, tok=self.other_token)
# invited to room (expect no content for room)
self.helper.invite(
await self.helper.invite(
room_id, src=self.other_user, targ=self.user_id, tok=self.other_token
)
# valid token, expect content
channel = self.make_request(
channel = await self.make_request(
"GET", "/events?access_token=%s&timeout=0" % (self.token,)
)
self.assertEqual(channel.code, 200, msg=channel.result)
@@ -117,7 +117,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
)
# joined room (expect all content for room)
self.helper.join(room=room_id, user=self.user_id, tok=self.token)
await self.helper.join(room=room_id, user=self.user_id, tok=self.token)
# left to room (expect no content for room)
@@ -146,18 +146,18 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# register an account
self.user_id = self.register_user("sid1", "pass")
self.token = self.login(self.user_id, "pass")
self.user_id = await self.register_user("sid1", "pass")
self.token = await self.login(self.user_id, "pass")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)
self.room_id = await self.helper.create_room_as(self.user_id, tok=self.token)
def test_get_event_via_events(self) -> None:
resp = self.helper.send(self.room_id, tok=self.token)
async def test_get_event_via_events(self) -> None:
resp = await self.helper.send(self.room_id, tok=self.token)
event_id = resp["event_id"]
channel = self.make_request(
channel = await self.make_request(
"GET",
"/events/" + event_id,
access_token=self.token,
+64 -64
View File
@@ -43,18 +43,18 @@ class UserMutualRoomsTest(unittest.HomeserverTestCase):
mutual_rooms.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
return self.setup_test_homeserver(config=config)
return await self.setup_test_homeserver(config=config)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
mutual_rooms.MUTUAL_ROOMS_BATCH_LIMIT = 10
def _get_mutual_rooms(
async def _get_mutual_rooms(
self, token: str, other_user: str, since_token: str | None = None
) -> FakeChannel:
return self.make_request(
return await self.make_request(
"GET",
"/_matrix/client/v1/mutual_rooms"
f"?user_id={quote(other_user)}"
@@ -62,183 +62,183 @@ class UserMutualRoomsTest(unittest.HomeserverTestCase):
access_token=token,
)
def test_shared_room_list_public(self) -> None:
async def test_shared_room_list_public(self) -> None:
"""
A room should show up in the shared list of rooms between two users
if it is public.
"""
self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=True)
await self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=True)
def test_shared_room_list_private(self) -> None:
async def test_shared_room_list_private(self) -> None:
"""
A room should show up in the shared list of rooms between two users
if it is private.
"""
self._check_mutual_rooms_with(
await self._check_mutual_rooms_with(
room_one_is_public=False, room_two_is_public=False
)
def test_shared_room_list_mixed(self) -> None:
async def test_shared_room_list_mixed(self) -> None:
"""
The shared room list between two users should contain both public and private
rooms.
"""
self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=False)
await self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=False)
def _check_mutual_rooms_with(
async def _check_mutual_rooms_with(
self, room_one_is_public: bool, room_two_is_public: bool
) -> None:
"""Checks that shared public or private rooms between two users appear in
their shared room lists
"""
u1 = self.register_user("user1", "pass")
u1_token = self.login(u1, "pass")
u2 = self.register_user("user2", "pass")
u2_token = self.login(u2, "pass")
u1 = await self.register_user("user1", "pass")
u1_token = await self.login(u1, "pass")
u2 = await self.register_user("user2", "pass")
u2_token = await self.login(u2, "pass")
# Create a room. user1 invites user2, who joins
room_id_one = self.helper.create_room_as(
room_id_one = await self.helper.create_room_as(
u1, is_public=room_one_is_public, tok=u1_token
)
self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token)
self.helper.join(room_id_one, user=u2, tok=u2_token)
await self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token)
await self.helper.join(room_id_one, user=u2, tok=u2_token)
# Check shared rooms from user1's perspective.
# We should see the one room in common
channel = self._get_mutual_rooms(u1_token, u2)
channel = await self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 1)
self.assertEqual(channel.json_body["count"], 1)
self.assertEqual(channel.json_body["joined"][0], room_id_one)
# Create another room and invite user2 to it
room_id_two = self.helper.create_room_as(
room_id_two = await self.helper.create_room_as(
u1, is_public=room_two_is_public, tok=u1_token
)
self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token)
self.helper.join(room_id_two, user=u2, tok=u2_token)
await self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token)
await self.helper.join(room_id_two, user=u2, tok=u2_token)
# Check shared rooms again. We should now see both rooms.
channel = self._get_mutual_rooms(u1_token, u2)
channel = await self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 2)
self.assertEqual(channel.json_body["count"], 2)
for room_id_id in channel.json_body["joined"]:
self.assertIn(room_id_id, [room_id_one, room_id_two])
def _create_rooms_for_pagination_test(
async def _create_rooms_for_pagination_test(
self, count: int
) -> tuple[str, str, list[str]]:
u1 = self.register_user("user1", "pass")
u1_token = self.login(u1, "pass")
u2 = self.register_user("user2", "pass")
u2_token = self.login(u2, "pass")
u1 = await self.register_user("user1", "pass")
u1_token = await self.login(u1, "pass")
u2 = await self.register_user("user2", "pass")
u2_token = await self.login(u2, "pass")
room_ids = []
for i in range(count):
room_id = self.helper.create_room_as(u1, is_public=i % 2 == 0, tok=u1_token)
self.helper.invite(room_id, src=u1, targ=u2, tok=u1_token)
self.helper.join(room_id, user=u2, tok=u2_token)
room_id = await self.helper.create_room_as(u1, is_public=i % 2 == 0, tok=u1_token)
await self.helper.invite(room_id, src=u1, targ=u2, tok=u1_token)
await self.helper.join(room_id, user=u2, tok=u2_token)
room_ids.append(room_id)
room_ids.sort()
return u1_token, u2, room_ids
def test_shared_room_list_pagination_two_pages(self) -> None:
u1_token, u2, room_ids = self._create_rooms_for_pagination_test(15)
async def test_shared_room_list_pagination_two_pages(self) -> None:
u1_token, u2, room_ids = await self._create_rooms_for_pagination_test(15)
channel = self._get_mutual_rooms(u1_token, u2)
channel = await self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(channel.json_body["joined"], room_ids[0:10])
self.assertEqual(channel.json_body["count"], 15)
self.assertIn("next_batch", channel.json_body)
channel = self._get_mutual_rooms(u1_token, u2, channel.json_body["next_batch"])
channel = await self._get_mutual_rooms(u1_token, u2, channel.json_body["next_batch"])
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(channel.json_body["joined"], room_ids[10:20])
self.assertEqual(channel.json_body["count"], 15)
self.assertNotIn("next_batch", channel.json_body)
def test_shared_room_list_pagination_one_page(self) -> None:
u1_token, u2, room_ids = self._create_rooms_for_pagination_test(10)
async def test_shared_room_list_pagination_one_page(self) -> None:
u1_token, u2, room_ids = await self._create_rooms_for_pagination_test(10)
channel = self._get_mutual_rooms(u1_token, u2)
channel = await self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(channel.json_body["joined"], room_ids)
self.assertEqual(channel.json_body["count"], 10)
self.assertNotIn("next_batch", channel.json_body)
def test_shared_room_list_pagination_invalid_token(self) -> None:
u1_token, u2, room_ids = self._create_rooms_for_pagination_test(10)
async def test_shared_room_list_pagination_invalid_token(self) -> None:
u1_token, u2, room_ids = await self._create_rooms_for_pagination_test(10)
channel = self._get_mutual_rooms(u1_token, u2, "!<>##faketoken")
channel = await self._get_mutual_rooms(u1_token, u2, "!<>##faketoken")
self.assertEqual(400, channel.code, channel.result)
self.assertEqual(
"M_INVALID_PARAM", channel.json_body["errcode"], channel.result
)
def test_shared_room_list_after_leave(self) -> None:
async def test_shared_room_list_after_leave(self) -> None:
"""
A room should no longer be considered shared if the other
user has left it.
"""
u1 = self.register_user("user1", "pass")
u1_token = self.login(u1, "pass")
u2 = self.register_user("user2", "pass")
u2_token = self.login(u2, "pass")
u1 = await self.register_user("user1", "pass")
u1_token = await self.login(u1, "pass")
u2 = await self.register_user("user2", "pass")
u2_token = await self.login(u2, "pass")
room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
self.helper.join(room, user=u2, tok=u2_token)
room = await self.helper.create_room_as(u1, is_public=True, tok=u1_token)
await self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
await self.helper.join(room, user=u2, tok=u2_token)
# Assert user directory is not empty
channel = self._get_mutual_rooms(u1_token, u2)
channel = await self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 1)
self.assertEqual(channel.json_body["count"], 1)
self.assertEqual(channel.json_body["joined"][0], room)
self.helper.leave(room, user=u1, tok=u1_token)
await self.helper.leave(room, user=u1, tok=u1_token)
# Check user1's view of shared rooms with user2
channel = self._get_mutual_rooms(u1_token, u2)
channel = await self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 0)
self.assertEqual(channel.json_body["count"], 0)
# Check user2's view of shared rooms with user1
channel = self._get_mutual_rooms(u2_token, u1)
channel = await self._get_mutual_rooms(u2_token, u1)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 0)
self.assertEqual(channel.json_body["count"], 0)
def test_shared_room_list_nonexistent_user(self) -> None:
u1 = self.register_user("user1", "pass")
u1_token = self.login(u1, "pass")
async def test_shared_room_list_nonexistent_user(self) -> None:
u1 = await self.register_user("user1", "pass")
u1_token = await self.login(u1, "pass")
# Check shared rooms from user1's perspective.
# We should see the one room in common
channel = self._get_mutual_rooms(u1_token, "@meow:example.com")
channel = await self._get_mutual_rooms(u1_token, "@meow:example.com")
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 0)
self.assertEqual(channel.json_body["count"], 0)
self.assertNotIn("next_batch", channel.json_body)
def test_shared_room_list_invalid_user(self) -> None:
u1 = self.register_user("user1", "pass")
u1_token = self.login(u1, "pass")
async def test_shared_room_list_invalid_user(self) -> None:
u1 = await self.register_user("user1", "pass")
u1_token = await self.login(u1, "pass")
channel = self._get_mutual_rooms(u1_token, "@:example.com")
channel = await self._get_mutual_rooms(u1_token, "@:example.com")
self.assertEqual(400, channel.code, channel.result)
self.assertEqual(
"M_INVALID_PARAM", channel.json_body["errcode"], channel.result
)
channel = self._get_mutual_rooms(u1_token, "@" + "a" * 255 + ":example.com")
channel = await self._get_mutual_rooms(u1_token, "@" + "a" * 255 + ":example.com")
self.assertEqual(400, channel.code, channel.result)
self.assertEqual(
"M_INVALID_PARAM", channel.json_body["errcode"], channel.result
)
channel = self._get_mutual_rooms(u1_token, "@🐈️:example.com")
channel = await self._get_mutual_rooms(u1_token, "@🐈️:example.com")
self.assertEqual(400, channel.code, channel.result)
self.assertEqual(
"M_INVALID_PARAM", channel.json_body["errcode"], channel.result
+27 -27
View File
@@ -39,7 +39,7 @@ class HTTPPusherTests(HomeserverTestCase):
notifications.register_servlets,
]
def prepare(
async def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.store = homeserver.get_datastores().main
@@ -48,32 +48,32 @@ class HTTPPusherTests(HomeserverTestCase):
self.sync_handler = homeserver.get_sync_handler()
self.auth_handler = homeserver.get_auth_handler()
self.user_id = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
self.user_id = await self.register_user("user", "pass")
self.access_token = await self.login("user", "pass")
self.other_user_id = await self.register_user("otheruser", "pass")
self.other_access_token = await self.login("otheruser", "pass")
# Create a room
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
self.room_id = await self.helper.create_room_as(self.user_id, tok=self.access_token)
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation.
fed_transport_client = Mock(spec=["send_transaction"])
fed_transport_client.send_transaction = AsyncMock(return_value={})
return self.setup_test_homeserver(
return await self.setup_test_homeserver(
federation_transport_client=fed_transport_client,
)
def test_notify_for_local_invites(self) -> None:
async def test_notify_for_local_invites(self) -> None:
"""
Local users will get notified for invites
"""
# Check we start with no pushes
self._request_notifications(from_token=None, limit=1, expected_count=0)
await self._request_notifications(from_token=None, limit=1, expected_count=0)
# Send an invite
self.helper.invite(
await self.helper.invite(
room=self.room_id,
src=self.user_id,
targ=self.other_user_id,
@@ -81,7 +81,7 @@ class HTTPPusherTests(HomeserverTestCase):
)
# We should have a notification now
channel = self.make_request(
channel = await self.make_request(
"GET",
"/notifications",
access_token=self.other_access_token,
@@ -94,26 +94,26 @@ class HTTPPusherTests(HomeserverTestCase):
channel.json_body,
)
def test_pagination_of_notifications(self) -> None:
async def test_pagination_of_notifications(self) -> None:
"""
Check that pagination of notifications works.
"""
# Check we start with no pushes
self._request_notifications(from_token=None, limit=1, expected_count=0)
await self._request_notifications(from_token=None, limit=1, expected_count=0)
# Send an invite and have the other user join the room.
self.helper.invite(
await self.helper.invite(
room=self.room_id,
src=self.user_id,
targ=self.other_user_id,
tok=self.access_token,
)
self.helper.join(self.room_id, self.other_user_id, tok=self.other_access_token)
await self.helper.join(self.room_id, self.other_user_id, tok=self.other_access_token)
# Send 5 messages in the room and note down their event IDs.
sent_event_ids = []
for _ in range(5):
resp = self.helper.send_event(
resp = await self.helper.send_event(
self.room_id,
"m.room.message",
{"body": "honk", "msgtype": "m.text"},
@@ -127,7 +127,7 @@ class HTTPPusherTests(HomeserverTestCase):
sent_event_ids.reverse()
# We should have a few notifications now. Let's try and fetch the first 2.
notification_event_ids, _ = self._request_notifications(
notification_event_ids, _ = await self._request_notifications(
from_token=None, limit=2, expected_count=2
)
@@ -136,7 +136,7 @@ class HTTPPusherTests(HomeserverTestCase):
# Try requesting again without a 'from' query parameter. We should get the
# same two notifications back.
notification_event_ids, next_token = self._request_notifications(
notification_event_ids, next_token = await self._request_notifications(
from_token=None, limit=2, expected_count=2
)
self.assertEqual(notification_event_ids, sent_event_ids[:2])
@@ -146,14 +146,14 @@ class HTTPPusherTests(HomeserverTestCase):
#
# We need to use the "next_token" from the response as the "from"
# query parameter in the next request in order to paginate.
notification_event_ids, next_token = self._request_notifications(
notification_event_ids, next_token = await self._request_notifications(
from_token=next_token, limit=5, expected_count=4
)
# Ensure we chop off the invite on the end.
notification_event_ids = notification_event_ids[:-1]
self.assertEqual(notification_event_ids, sent_event_ids[2:])
def _request_notifications(
async def _request_notifications(
self, from_token: str | None, limit: int, expected_count: int
) -> tuple[list[str], str]:
"""
@@ -175,7 +175,7 @@ class HTTPPusherTests(HomeserverTestCase):
if from_token is not None:
path += f"&from={from_token}"
channel = self.make_request(
channel = await self.make_request(
"GET",
path,
access_token=self.other_access_token,
@@ -194,12 +194,12 @@ class HTTPPusherTests(HomeserverTestCase):
return event_ids, next_token
def test_parameters(self) -> None:
async def test_parameters(self) -> None:
"""
Test that appropriate errors are returned when query parameters are malformed.
"""
# Test that no parameters are required.
channel = self.make_request(
channel = await self.make_request(
"GET",
"/notifications",
access_token=self.other_access_token,
@@ -207,7 +207,7 @@ class HTTPPusherTests(HomeserverTestCase):
self.assertEqual(channel.code, 200)
# Test that limit cannot be negative
channel = self.make_request(
channel = await self.make_request(
"GET",
"/notifications?limit=-1",
access_token=self.other_access_token,
@@ -215,7 +215,7 @@ class HTTPPusherTests(HomeserverTestCase):
self.assertEqual(channel.code, 400)
# Test that the 'limit' parameter must be an integer.
channel = self.make_request(
channel = await self.make_request(
"GET",
"/notifications?limit=foobar",
access_token=self.other_access_token,
@@ -223,7 +223,7 @@ class HTTPPusherTests(HomeserverTestCase):
self.assertEqual(channel.code, 400)
# Test that the 'from' parameter must be an integer.
channel = self.make_request(
channel = await self.make_request(
"GET",
"/notifications?from=osborne",
access_token=self.other_access_token,
+42 -42
View File
@@ -42,33 +42,33 @@ class PowerLevelsTestCase(HomeserverTestCase):
sync.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
return self.setup_test_homeserver(config=config)
return await self.setup_test_homeserver(config=config)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# register a room admin, moderator and regular user
self.admin_user_id = self.register_user("admin", "pass")
self.admin_access_token = self.login("admin", "pass")
self.mod_user_id = self.register_user("mod", "pass")
self.mod_access_token = self.login("mod", "pass")
self.user_user_id = self.register_user("user", "pass")
self.user_access_token = self.login("user", "pass")
self.admin_user_id = await self.register_user("admin", "pass")
self.admin_access_token = await self.login("admin", "pass")
self.mod_user_id = await self.register_user("mod", "pass")
self.mod_access_token = await self.login("mod", "pass")
self.user_user_id = await self.register_user("user", "pass")
self.user_access_token = await self.login("user", "pass")
# Create a room
self.room_id = self.helper.create_room_as(
self.room_id = await self.helper.create_room_as(
self.admin_user_id, tok=self.admin_access_token
)
# Invite the other users
self.helper.invite(
await self.helper.invite(
room=self.room_id,
src=self.admin_user_id,
tok=self.admin_access_token,
targ=self.mod_user_id,
)
self.helper.invite(
await self.helper.invite(
room=self.room_id,
src=self.admin_user_id,
tok=self.admin_access_token,
@@ -76,15 +76,15 @@ class PowerLevelsTestCase(HomeserverTestCase):
)
# Make the other users join the room
self.helper.join(
await self.helper.join(
room=self.room_id, user=self.mod_user_id, tok=self.mod_access_token
)
self.helper.join(
await self.helper.join(
room=self.room_id, user=self.user_user_id, tok=self.user_access_token
)
# Mod the mod
room_power_levels = self.helper.get_state(
room_power_levels = await self.helper.get_state(
self.room_id,
"m.room.power_levels",
tok=self.admin_access_token,
@@ -93,16 +93,16 @@ class PowerLevelsTestCase(HomeserverTestCase):
# Update existing power levels with mod at PL50
room_power_levels["users"].update({self.mod_user_id: 50})
self.helper.send_state(
await self.helper.send_state(
self.room_id,
"m.room.power_levels",
room_power_levels,
tok=self.admin_access_token,
)
def test_non_admins_cannot_enable_room_encryption(self) -> None:
async def test_non_admins_cannot_enable_room_encryption(self) -> None:
# have the mod try to enable room encryption
self.helper.send_state(
await self.helper.send_state(
self.room_id,
"m.room.encryption",
{"algorithm": "m.megolm.v1.aes-sha2"},
@@ -111,7 +111,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
)
# have the user try to enable room encryption
self.helper.send_state(
await self.helper.send_state(
self.room_id,
"m.room.encryption",
{"algorithm": "m.megolm.v1.aes-sha2"},
@@ -119,9 +119,9 @@ class PowerLevelsTestCase(HomeserverTestCase):
expect_code=HTTPStatus.FORBIDDEN, # expect failure
)
def test_non_admins_cannot_send_server_acl(self) -> None:
async def test_non_admins_cannot_send_server_acl(self) -> None:
# have the mod try to send a server ACL
self.helper.send_state(
await self.helper.send_state(
self.room_id,
"m.room.server_acl",
{
@@ -134,7 +134,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
)
# have the user try to send a server ACL
self.helper.send_state(
await self.helper.send_state(
self.room_id,
"m.room.server_acl",
{
@@ -146,14 +146,14 @@ class PowerLevelsTestCase(HomeserverTestCase):
expect_code=HTTPStatus.FORBIDDEN, # expect failure
)
def test_non_admins_cannot_tombstone_room(self) -> None:
async def test_non_admins_cannot_tombstone_room(self) -> None:
# Create another room that will serve as our "upgraded room"
self.upgraded_room_id = self.helper.create_room_as(
self.upgraded_room_id = await self.helper.create_room_as(
self.admin_user_id, tok=self.admin_access_token
)
# have the mod try to send a tombstone event
self.helper.send_state(
await self.helper.send_state(
self.room_id,
"m.room.tombstone",
{
@@ -165,7 +165,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
)
# have the user try to send a tombstone event
self.helper.send_state(
await self.helper.send_state(
self.room_id,
"m.room.tombstone",
{
@@ -176,9 +176,9 @@ class PowerLevelsTestCase(HomeserverTestCase):
expect_code=403, # expect failure
)
def test_admins_can_enable_room_encryption(self) -> None:
async def test_admins_can_enable_room_encryption(self) -> None:
# have the admin try to enable room encryption
self.helper.send_state(
await self.helper.send_state(
self.room_id,
"m.room.encryption",
{"algorithm": "m.megolm.v1.aes-sha2"},
@@ -186,9 +186,9 @@ class PowerLevelsTestCase(HomeserverTestCase):
expect_code=HTTPStatus.OK, # expect success
)
def test_admins_can_send_server_acl(self) -> None:
async def test_admins_can_send_server_acl(self) -> None:
# have the admin try to send a server ACL
self.helper.send_state(
await self.helper.send_state(
self.room_id,
"m.room.server_acl",
{
@@ -200,14 +200,14 @@ class PowerLevelsTestCase(HomeserverTestCase):
expect_code=HTTPStatus.OK, # expect success
)
def test_admins_can_tombstone_room(self) -> None:
async def test_admins_can_tombstone_room(self) -> None:
# Create another room that will serve as our "upgraded room"
self.upgraded_room_id = self.helper.create_room_as(
self.upgraded_room_id = await self.helper.create_room_as(
self.admin_user_id, tok=self.admin_access_token
)
# have the admin try to send a tombstone event
self.helper.send_state(
await self.helper.send_state(
self.room_id,
"m.room.tombstone",
{
@@ -218,8 +218,8 @@ class PowerLevelsTestCase(HomeserverTestCase):
expect_code=HTTPStatus.OK, # expect success
)
def test_cannot_set_string_power_levels(self) -> None:
room_power_levels = self.helper.get_state(
async def test_cannot_set_string_power_levels(self) -> None:
room_power_levels = await self.helper.get_state(
self.room_id,
"m.room.power_levels",
tok=self.admin_access_token,
@@ -228,7 +228,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
# Update existing power levels with user at PL "0"
room_power_levels["users"].update({self.user_user_id: "0"})
body = self.helper.send_state(
body = await self.helper.send_state(
self.room_id,
"m.room.power_levels",
room_power_levels,
@@ -242,8 +242,8 @@ class PowerLevelsTestCase(HomeserverTestCase):
body,
)
def test_cannot_set_unsafe_large_power_levels(self) -> None:
room_power_levels = self.helper.get_state(
async def test_cannot_set_unsafe_large_power_levels(self) -> None:
room_power_levels = await self.helper.get_state(
self.room_id,
"m.room.power_levels",
tok=self.admin_access_token,
@@ -254,7 +254,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
{self.user_user_id: CANONICALJSON_MAX_INT + 1}
)
body = self.helper.send_state(
body = await self.helper.send_state(
self.room_id,
"m.room.power_levels",
room_power_levels,
@@ -268,8 +268,8 @@ class PowerLevelsTestCase(HomeserverTestCase):
body,
)
def test_cannot_set_unsafe_small_power_levels(self) -> None:
room_power_levels = self.helper.get_state(
async def test_cannot_set_unsafe_small_power_levels(self) -> None:
room_power_levels = await self.helper.get_state(
self.room_id,
"m.room.power_levels",
tok=self.admin_access_token,
@@ -280,7 +280,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
{self.user_user_id: CANONICALJSON_MIN_INT - 1}
)
body = self.helper.send_state(
body = await self.helper.send_state(
self.room_id,
"m.room.power_levels",
room_power_levels,
+14 -14
View File
@@ -40,11 +40,11 @@ class PresenceTestCase(unittest.HomeserverTestCase):
user = UserID.from_string(user_id)
servlets = [presence.register_servlets]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.presence_handler = Mock(spec=PresenceHandler)
self.presence_handler.set_state = AsyncMock(return_value=None)
hs = self.setup_test_homeserver(
hs = await self.setup_test_homeserver(
"red",
federation_client=Mock(),
presence_handler=self.presence_handler,
@@ -52,7 +52,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
return hs
def test_put_presence(self) -> None:
async def test_put_presence(self) -> None:
"""
PUT to the status endpoint with use_presence enabled will call
set_state on the presence handler.
@@ -60,7 +60,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
self.hs.config.server.presence_enabled = True
body = {"presence": "here", "status_msg": "beep boop"}
channel = self.make_request(
channel = await self.make_request(
"PUT", "/presence/%s/status" % (self.user_id,), body
)
@@ -68,14 +68,14 @@ class PresenceTestCase(unittest.HomeserverTestCase):
self.assertEqual(self.presence_handler.set_state.call_count, 1)
@unittest.override_config({"use_presence": False})
def test_put_presence_disabled(self) -> None:
async def test_put_presence_disabled(self) -> None:
"""
PUT to the status endpoint with presence disabled will NOT call
set_state on the presence handler.
"""
body = {"presence": "here", "status_msg": "beep boop"}
channel = self.make_request(
channel = await self.make_request(
"PUT", "/presence/%s/status" % (self.user_id,), body
)
@@ -83,14 +83,14 @@ class PresenceTestCase(unittest.HomeserverTestCase):
self.assertEqual(self.presence_handler.set_state.call_count, 0)
@unittest.override_config({"presence": {"enabled": "untracked"}})
def test_put_presence_untracked(self) -> None:
async def test_put_presence_untracked(self) -> None:
"""
PUT to the status endpoint with presence untracked will NOT call
set_state on the presence handler.
"""
body = {"presence": "here", "status_msg": "beep boop"}
channel = self.make_request(
channel = await self.make_request(
"PUT", "/presence/%s/status" % (self.user_id,), body
)
@@ -100,21 +100,21 @@ class PresenceTestCase(unittest.HomeserverTestCase):
@override_config(
{"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
)
def test_put_presence_over_ratelimit(self) -> None:
async def test_put_presence_over_ratelimit(self) -> None:
"""
Multiple PUTs to the status endpoint without sufficient delay will be rate limited.
"""
self.hs.config.server.presence_enabled = True
body = {"presence": "here", "status_msg": "beep boop"}
channel = self.make_request(
channel = await self.make_request(
"PUT", "/presence/%s/status" % (self.user_id,), body
)
self.assertEqual(channel.code, HTTPStatus.OK)
body = {"presence": "here", "status_msg": "beep boop"}
channel = self.make_request(
channel = await self.make_request(
"PUT", "/presence/%s/status" % (self.user_id,), body
)
@@ -124,14 +124,14 @@ class PresenceTestCase(unittest.HomeserverTestCase):
@override_config(
{"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
)
def test_put_presence_within_ratelimit(self) -> None:
async def test_put_presence_within_ratelimit(self) -> None:
"""
Multiple PUTs to the status endpoint with sufficient delay should all call set_state.
"""
self.hs.config.server.presence_enabled = True
body = {"presence": "here", "status_msg": "beep boop"}
channel = self.make_request(
channel = await self.make_request(
"PUT", "/presence/%s/status" % (self.user_id,), body
)
@@ -141,7 +141,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
self.reactor.advance(30)
body = {"presence": "here", "status_msg": "beep boop"}
channel = self.make_request(
channel = await self.make_request(
"PUT", "/presence/%s/status" % (self.user_id,), body
)
+184 -165
View File
@@ -50,27 +50,27 @@ class ProfileTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = self.setup_test_homeserver()
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = await self.setup_test_homeserver()
return self.hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.owner = self.register_user("owner", "pass")
self.owner_tok = self.login("owner", "pass")
self.other = self.register_user("other", "pass", displayname="Bob")
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.owner = await self.register_user("owner", "pass")
self.owner_tok = await self.login("owner", "pass")
self.other = await self.register_user("other", "pass", displayname="Bob")
def test_get_displayname(self) -> None:
res = self._get_displayname()
async def test_get_displayname(self) -> None:
res = await self._get_displayname()
self.assertEqual(res, "owner")
def test_get_displayname_rejects_bad_username(self) -> None:
channel = self.make_request(
async def test_get_displayname_rejects_bad_username(self) -> None:
channel = await self.make_request(
"GET", f"/profile/{urllib.parse.quote('@alice:')}/displayname"
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
def test_set_displayname(self) -> None:
channel = self.make_request(
async def test_set_displayname(self) -> None:
channel = await self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
content={"displayname": "test"},
@@ -78,11 +78,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
res = self._get_displayname()
res = await self._get_displayname()
self.assertEqual(res, "test")
def test_set_displayname_with_extra_spaces(self) -> None:
channel = self.make_request(
async def test_set_displayname_with_extra_spaces(self) -> None:
channel = await self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
content={"displayname": " test "},
@@ -90,20 +90,20 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
res = self._get_displayname()
res = await self._get_displayname()
self.assertEqual(res, "test")
def test_set_displayname_noauth(self) -> None:
channel = self.make_request(
async def test_set_displayname_noauth(self) -> None:
channel = await self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
content={"displayname": "test"},
)
self.assertEqual(channel.code, 401, channel.result)
def test_set_displayname_too_long(self) -> None:
async def test_set_displayname_too_long(self) -> None:
"""Attempts to set a stupid displayname should get a 400"""
channel = self.make_request(
channel = await self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
content={"displayname": "test" * 100},
@@ -111,15 +111,15 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 400, channel.result)
res = self._get_displayname()
res = await self._get_displayname()
self.assertEqual(res, "owner")
def test_get_displayname_other(self) -> None:
res = self._get_displayname(self.other)
async def test_get_displayname_other(self) -> None:
res = await self._get_displayname(self.other)
self.assertEqual(res, "Bob")
def test_set_displayname_other(self) -> None:
channel = self.make_request(
async def test_set_displayname_other(self) -> None:
channel = await self.make_request(
"PUT",
"/profile/%s/displayname" % (self.other,),
content={"displayname": "test"},
@@ -127,12 +127,12 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 400, channel.result)
def test_get_avatar_url(self) -> None:
res = self._get_avatar_url()
async def test_get_avatar_url(self) -> None:
res = await self._get_avatar_url()
self.assertIsNone(res)
def test_set_avatar_url(self) -> None:
channel = self.make_request(
async def test_set_avatar_url(self) -> None:
channel = await self.make_request(
"PUT",
"/profile/%s/avatar_url" % (self.owner,),
content={"avatar_url": "http://my.server/pic.gif"},
@@ -140,20 +140,20 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
res = self._get_avatar_url()
res = await self._get_avatar_url()
self.assertEqual(res, "http://my.server/pic.gif")
def test_set_avatar_url_noauth(self) -> None:
channel = self.make_request(
async def test_set_avatar_url_noauth(self) -> None:
channel = await self.make_request(
"PUT",
"/profile/%s/avatar_url" % (self.owner,),
content={"avatar_url": "http://my.server/pic.gif"},
)
self.assertEqual(channel.code, 401, channel.result)
def test_set_avatar_url_too_long(self) -> None:
async def test_set_avatar_url_too_long(self) -> None:
"""Attempts to set a stupid avatar_url should get a 400"""
channel = self.make_request(
channel = await self.make_request(
"PUT",
"/profile/%s/avatar_url" % (self.owner,),
content={"avatar_url": "http://my.server/pic.gif" * 100},
@@ -161,15 +161,15 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 400, channel.result)
res = self._get_avatar_url()
res = await self._get_avatar_url()
self.assertIsNone(res)
def test_get_avatar_url_other(self) -> None:
res = self._get_avatar_url(self.other)
async def test_get_avatar_url_other(self) -> None:
res = await self._get_avatar_url(self.other)
self.assertIsNone(res)
def test_set_avatar_url_other(self) -> None:
channel = self.make_request(
async def test_set_avatar_url_other(self) -> None:
channel = await self.make_request(
"PUT",
"/profile/%s/avatar_url" % (self.other,),
content={"avatar_url": "http://my.server/pic.gif"},
@@ -177,8 +177,8 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 400, channel.result)
def _get_displayname(self, name: str | None = None) -> str | None:
channel = self.make_request(
async def _get_displayname(self, name: str | None = None) -> str | None:
channel = await self.make_request(
"GET", "/profile/%s/displayname" % (name or self.owner,)
)
self.assertEqual(channel.code, 200, channel.result)
@@ -187,8 +187,8 @@ class ProfileTestCase(unittest.HomeserverTestCase):
# https://github.com/matrix-org/synapse/issues/13137.
return channel.json_body.get("displayname")
def _get_avatar_url(self, name: str | None = None) -> str | None:
channel = self.make_request(
async def _get_avatar_url(self, name: str | None = None) -> str | None:
channel = await self.make_request(
"GET", "/profile/%s/avatar_url" % (name or self.owner,)
)
self.assertEqual(channel.code, 200, channel.result)
@@ -198,18 +198,18 @@ class ProfileTestCase(unittest.HomeserverTestCase):
return channel.json_body.get("avatar_url")
@unittest.override_config({"max_avatar_size": 50})
def test_avatar_size_limit_global(self) -> None:
async def test_avatar_size_limit_global(self) -> None:
"""Tests that the maximum size limit for avatars is enforced when updating a
global profile.
"""
self._setup_local_files(
await self._setup_local_files(
{
"small": {"size": 40},
"big": {"size": 60},
}
)
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/profile/{self.owner}/avatar_url",
content={"avatar_url": "mxc://test/big"},
@@ -220,7 +220,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body
)
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/profile/{self.owner}/avatar_url",
content={"avatar_url": "mxc://test/small"},
@@ -229,20 +229,20 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
@unittest.override_config({"max_avatar_size": 50})
def test_avatar_size_limit_per_room(self) -> None:
async def test_avatar_size_limit_per_room(self) -> None:
"""Tests that the maximum size limit for avatars is enforced when updating a
per-room profile.
"""
self._setup_local_files(
await self._setup_local_files(
{
"small": {"size": 40},
"big": {"size": 60},
}
)
room_id = self.helper.create_room_as(tok=self.owner_tok)
room_id = await self.helper.create_room_as(tok=self.owner_tok)
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
content={"membership": "join", "avatar_url": "mxc://test/big"},
@@ -253,7 +253,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body
)
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
content={"membership": "join", "avatar_url": "mxc://test/small"},
@@ -262,18 +262,18 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
@unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
def test_avatar_allowed_mime_type_global(self) -> None:
async def test_avatar_allowed_mime_type_global(self) -> None:
"""Tests that the MIME type whitelist for avatars is enforced when updating a
global profile.
"""
self._setup_local_files(
await self._setup_local_files(
{
"good": {"mimetype": "image/png"},
"bad": {"mimetype": "application/octet-stream"},
}
)
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/profile/{self.owner}/avatar_url",
content={"avatar_url": "mxc://test/bad"},
@@ -284,7 +284,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body
)
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/profile/{self.owner}/avatar_url",
content={"avatar_url": "mxc://test/good"},
@@ -293,20 +293,20 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
@unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
def test_avatar_allowed_mime_type_per_room(self) -> None:
async def test_avatar_allowed_mime_type_per_room(self) -> None:
"""Tests that the MIME type whitelist for avatars is enforced when updating a
per-room profile.
"""
self._setup_local_files(
await self._setup_local_files(
{
"good": {"mimetype": "image/png"},
"bad": {"mimetype": "application/octet-stream"},
}
)
room_id = self.helper.create_room_as(tok=self.owner_tok)
room_id = await self.helper.create_room_as(tok=self.owner_tok)
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
content={"membership": "join", "avatar_url": "mxc://test/bad"},
@@ -317,7 +317,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body
)
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
content={"membership": "join", "avatar_url": "mxc://test/good"},
@@ -328,12 +328,12 @@ class ProfileTestCase(unittest.HomeserverTestCase):
@unittest.override_config(
{"experimental_features": {"msc4069_profile_inhibit_propagation": True}}
)
def test_msc4069_inhibit_propagation(self) -> None:
async def test_msc4069_inhibit_propagation(self) -> None:
"""Tests to ensure profile update propagation can be inhibited."""
for prop in ["avatar_url", "displayname"]:
room_id = self.helper.create_room_as(tok=self.owner_tok)
room_id = await self.helper.create_room_as(tok=self.owner_tok)
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
content={"membership": "join", prop: "mxc://my.server/existing"},
@@ -341,7 +341,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/profile/{self.owner}/{prop}?org.matrix.msc4069.propagate=false",
content={prop: "http://my.server/pic.gif"},
@@ -350,13 +350,13 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
res = (
self._get_avatar_url()
await self._get_avatar_url()
if prop == "avatar_url"
else self._get_displayname()
else await self._get_displayname()
)
self.assertEqual(res, "http://my.server/pic.gif")
channel = self.make_request(
channel = await self.make_request(
"GET",
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
access_token=self.owner_tok,
@@ -364,14 +364,14 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body.get(prop), "mxc://my.server/existing")
def test_msc4069_inhibit_propagation_disabled(self) -> None:
async def test_msc4069_inhibit_propagation_disabled(self) -> None:
"""Tests to ensure profile update propagation inhibit flags are ignored when the
experimental flag is not enabled.
"""
for prop in ["avatar_url", "displayname"]:
room_id = self.helper.create_room_as(tok=self.owner_tok)
room_id = await self.helper.create_room_as(tok=self.owner_tok)
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
content={"membership": "join", prop: "mxc://my.server/existing"},
@@ -379,7 +379,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/profile/{self.owner}/{prop}?org.matrix.msc4069.propagate=false",
content={prop: "http://my.server/pic.gif"},
@@ -387,14 +387,16 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
await self._wait_for_profile_propagation()
res = (
self._get_avatar_url()
await self._get_avatar_url()
if prop == "avatar_url"
else self._get_displayname()
else await self._get_displayname()
)
self.assertEqual(res, "http://my.server/pic.gif")
channel = self.make_request(
channel = await self.make_request(
"GET",
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
access_token=self.owner_tok,
@@ -405,12 +407,25 @@ class ProfileTestCase(unittest.HomeserverTestCase):
# isn't enabled.
self.assertEqual(channel.json_body.get(prop), "http://my.server/pic.gif")
def test_msc4069_inhibit_propagation_default(self) -> None:
async def _wait_for_profile_propagation(self) -> None:
"""Wait for the background task that propagates profile changes to rooms."""
import asyncio
from synapse.util.task_scheduler import TaskStatus
for _ in range(50):
tasks = await self.hs.get_task_scheduler().get_tasks(
actions=["update_join_states"],
statuses=[TaskStatus.ACTIVE, TaskStatus.SCHEDULED],
)
if not tasks:
break
await asyncio.sleep(0.05)
async def test_msc4069_inhibit_propagation_default(self) -> None:
"""Tests to ensure profile update propagation happens by default."""
for prop in ["avatar_url", "displayname"]:
room_id = self.helper.create_room_as(tok=self.owner_tok)
room_id = await self.helper.create_room_as(tok=self.owner_tok)
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
content={"membership": "join", prop: "mxc://my.server/existing"},
@@ -418,7 +433,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/profile/{self.owner}/{prop}",
content={prop: "http://my.server/pic.gif"},
@@ -426,14 +441,16 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
await self._wait_for_profile_propagation()
res = (
self._get_avatar_url()
await self._get_avatar_url()
if prop == "avatar_url"
else self._get_displayname()
else await self._get_displayname()
)
self.assertEqual(res, "http://my.server/pic.gif")
channel = self.make_request(
channel = await self.make_request(
"GET",
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
access_token=self.owner_tok,
@@ -447,12 +464,12 @@ class ProfileTestCase(unittest.HomeserverTestCase):
@unittest.override_config(
{"experimental_features": {"msc4069_profile_inhibit_propagation": True}}
)
def test_msc4069_inhibit_propagation_like_default(self) -> None:
async def test_msc4069_inhibit_propagation_like_default(self) -> None:
"""Tests to ensure clients can request explicit profile propagation."""
for prop in ["avatar_url", "displayname"]:
room_id = self.helper.create_room_as(tok=self.owner_tok)
room_id = await self.helper.create_room_as(tok=self.owner_tok)
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
content={"membership": "join", prop: "mxc://my.server/existing"},
@@ -460,7 +477,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/profile/{self.owner}/{prop}?org.matrix.msc4069.propagate=true",
content={prop: "http://my.server/pic.gif"},
@@ -468,14 +485,16 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
await self._wait_for_profile_propagation()
res = (
self._get_avatar_url()
await self._get_avatar_url()
if prop == "avatar_url"
else self._get_displayname()
else await self._get_displayname()
)
self.assertEqual(res, "http://my.server/pic.gif")
channel = self.make_request(
channel = await self.make_request(
"GET",
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
access_token=self.owner_tok,
@@ -485,32 +504,32 @@ class ProfileTestCase(unittest.HomeserverTestCase):
# The client requested ?propagate=true, so it should have happened.
self.assertEqual(channel.json_body.get(prop), "http://my.server/pic.gif")
def test_get_missing_custom_field(self) -> None:
channel = self.make_request(
async def test_get_missing_custom_field(self) -> None:
channel = await self.make_request(
"GET",
f"/_matrix/client/v3/profile/{self.owner}/custom_field",
)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_get_missing_custom_field_invalid_field_name(self) -> None:
channel = self.make_request(
async def test_get_missing_custom_field_invalid_field_name(self) -> None:
channel = await self.make_request(
"GET",
f"/_matrix/client/v3/profile/{self.owner}/[custom_field]",
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_get_custom_field_rejects_bad_username(self) -> None:
channel = self.make_request(
async def test_get_custom_field_rejects_bad_username(self) -> None:
channel = await self.make_request(
"GET",
f"/_matrix/client/v3/profile/{urllib.parse.quote('@alice:')}/custom_field",
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_set_custom_field(self) -> None:
channel = self.make_request(
async def test_set_custom_field(self) -> None:
channel = await self.make_request(
"PUT",
f"/_matrix/client/v3/profile/{self.owner}/custom_field",
content={"custom_field": "test"},
@@ -518,7 +537,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
channel = self.make_request(
channel = await self.make_request(
"GET",
f"/_matrix/client/v3/profile/{self.owner}/custom_field",
)
@@ -526,7 +545,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body, {"custom_field": "test"})
# Overwriting the field should work.
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/_matrix/client/v3/profile/{self.owner}/custom_field",
content={"custom_field": "new_Value"},
@@ -534,7 +553,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
channel = self.make_request(
channel = await self.make_request(
"GET",
f"/_matrix/client/v3/profile/{self.owner}/custom_field",
)
@@ -542,7 +561,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body, {"custom_field": "new_Value"})
# Deleting the field should work.
channel = self.make_request(
channel = await self.make_request(
"DELETE",
f"/_matrix/client/v3/profile/{self.owner}/custom_field",
content={},
@@ -550,14 +569,14 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
channel = self.make_request(
channel = await self.make_request(
"GET",
f"/_matrix/client/v3/profile/{self.owner}/custom_field",
)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_non_string(self) -> None:
async def test_non_string(self) -> None:
"""Non-string fields are supported for custom fields."""
fields = {
"bool_field": True,
@@ -568,7 +587,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
}
for key, value in fields.items():
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/_matrix/client/v3/profile/{self.owner}/{key}",
content={key: value},
@@ -576,7 +595,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
channel = self.make_request(
channel = await self.make_request(
"GET",
f"/_matrix/client/v3/profile/{self.owner}",
)
@@ -585,15 +604,15 @@ class ProfileTestCase(unittest.HomeserverTestCase):
# Check getting individual fields works.
for key, value in fields.items():
channel = self.make_request(
channel = await self.make_request(
"GET",
f"/_matrix/client/v3/profile/{self.owner}/{key}",
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.json_body, {key: value})
def test_set_custom_field_noauth(self) -> None:
channel = self.make_request(
async def test_set_custom_field_noauth(self) -> None:
channel = await self.make_request(
"PUT",
f"/_matrix/client/v3/profile/{self.owner}/custom_field",
content={"custom_field": "test"},
@@ -601,12 +620,12 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.MISSING_TOKEN)
def test_set_custom_field_size(self) -> None:
async def test_set_custom_field_size(self) -> None:
"""
Attempts to set a custom field name that is too long should get a 400 error.
"""
# Key is missing.
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/_matrix/client/v3/profile/{self.owner}/",
content={"": "test"},
@@ -617,7 +636,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
# Single key is too large.
key = "c" * 500
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/_matrix/client/v3/profile/{self.owner}/{key}",
content={key: "test"},
@@ -626,7 +645,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.KEY_TOO_LARGE)
channel = self.make_request(
channel = await self.make_request(
"DELETE",
f"/_matrix/client/v3/profile/{self.owner}/{key}",
content={key: "test"},
@@ -636,7 +655,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], Codes.KEY_TOO_LARGE)
# Key doesn't match body.
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/_matrix/client/v3/profile/{self.owner}/custom_field",
content={"diff_key": "test"},
@@ -645,7 +664,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
def test_set_custom_field_profile_too_long(self) -> None:
async def test_set_custom_field_profile_too_long(self) -> None:
"""
Attempts to set a custom field that would push the overall profile too large.
"""
@@ -665,7 +684,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
# 2 braces, 1 comma
# 3 + 21 + 65498 = 65522 < 65536.
key = "a"
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/_matrix/client/v3/profile/{self.owner}/{key}",
content={key: "a" * 65498},
@@ -674,7 +693,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
# Get the entire profile.
channel = self.make_request(
channel = await self.make_request(
"GET",
f"/_matrix/client/v3/profile/{self.owner}",
access_token=self.owner_tok,
@@ -693,7 +712,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
# The next one should fail, note the value has a (JSON) length of 2.
key = "b"
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/_matrix/client/v3/profile/{self.owner}/{key}",
content={key: "1" + "a" * ADDITIONAL_CHARS},
@@ -703,7 +722,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
# Setting an avatar or (longer) display name should not work.
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/profile/{self.owner}/displayname",
content={"displayname": "owner12345678" + "a" * ADDITIONAL_CHARS},
@@ -712,7 +731,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/profile/{self.owner}/avatar_url",
content={"avatar_url": "mxc://foo/bar"},
@@ -723,7 +742,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
# Removing a single byte should work.
key = "b"
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/_matrix/client/v3/profile/{self.owner}/{key}",
content={key: "" + "a" * ADDITIONAL_CHARS},
@@ -733,7 +752,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
# Finally, setting a field that already exists to a value that is <= in length should work.
key = "a"
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/_matrix/client/v3/profile/{self.owner}/{key}",
content={key: ""},
@@ -743,8 +762,8 @@ class ProfileTestCase(unittest.HomeserverTestCase):
finally:
sql_logger.disabled = sql_logger_was_disabled
def test_set_custom_field_displayname(self) -> None:
channel = self.make_request(
async def test_set_custom_field_displayname(self) -> None:
channel = await self.make_request(
"PUT",
f"/_matrix/client/v3/profile/{self.owner}/displayname",
content={"displayname": "test"},
@@ -752,11 +771,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
displayname = self._get_displayname()
displayname = await self._get_displayname()
self.assertEqual(displayname, "test")
def test_set_custom_field_avatar_url(self) -> None:
channel = self.make_request(
async def test_set_custom_field_avatar_url(self) -> None:
channel = await self.make_request(
"PUT",
f"/_matrix/client/v3/profile/{self.owner}/avatar_url",
content={"avatar_url": "mxc://test/good"},
@@ -764,12 +783,12 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
avatar_url = self._get_avatar_url()
avatar_url = await self._get_avatar_url()
self.assertEqual(avatar_url, "mxc://test/good")
def test_set_custom_field_other(self) -> None:
async def test_set_custom_field_other(self) -> None:
"""Setting someone else's profile field should fail"""
channel = self.make_request(
channel = await self.make_request(
"PUT",
f"/_matrix/client/v3/profile/{self.other}/custom_field",
content={"custom_field": "test"},
@@ -778,7 +797,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 403, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
def _setup_local_files(self, names_and_props: dict[str, dict[str, Any]]) -> None:
async def _setup_local_files(self, names_and_props: dict[str, dict[str, Any]]) -> None:
"""Stores metadata about files in the database.
Args:
@@ -790,7 +809,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
store = self.hs.get_datastores().main
for name, props in names_and_props.items():
self.get_success(
await self.get_success(
store.store_local_media(
media_id=name,
media_type=props.get("mimetype", "image/png"),
@@ -810,68 +829,68 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["require_auth_for_profile_requests"] = True
config["limit_profile_requests_to_users_who_share_rooms"] = True
self.hs = self.setup_test_homeserver(config=config)
self.hs = await self.setup_test_homeserver(config=config)
return self.hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# User owning the requested profile.
self.owner = self.register_user("owner", "pass")
self.owner_tok = self.login("owner", "pass")
self.owner = await self.register_user("owner", "pass")
self.owner_tok = await self.login("owner", "pass")
self.profile_url = "/profile/%s" % (self.owner)
# User requesting the profile.
self.requester = self.register_user("requester", "pass")
self.requester_tok = self.login("requester", "pass")
self.requester = await self.register_user("requester", "pass")
self.requester_tok = await self.login("requester", "pass")
self.room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok)
self.room_id = await self.helper.create_room_as(self.owner, tok=self.owner_tok)
def test_no_auth(self) -> None:
self.try_fetch_profile(401)
async def test_no_auth(self) -> None:
await self.try_fetch_profile(401)
def test_not_in_shared_room(self) -> None:
self.ensure_requester_left_room()
async def test_not_in_shared_room(self) -> None:
await self.ensure_requester_left_room()
self.try_fetch_profile(403, access_token=self.requester_tok)
await self.try_fetch_profile(403, access_token=self.requester_tok)
def test_in_shared_room(self) -> None:
self.ensure_requester_left_room()
async def test_in_shared_room(self) -> None:
await self.ensure_requester_left_room()
self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok)
await self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok)
self.try_fetch_profile(200, self.requester_tok)
await self.try_fetch_profile(200, self.requester_tok)
def try_fetch_profile(
async def try_fetch_profile(
self, expected_code: int, access_token: str | None = None
) -> None:
self.request_profile(expected_code, access_token=access_token)
await self.request_profile(expected_code, access_token=access_token)
self.request_profile(
await self.request_profile(
expected_code, url_suffix="/displayname", access_token=access_token
)
self.request_profile(
await self.request_profile(
expected_code, url_suffix="/avatar_url", access_token=access_token
)
def request_profile(
async def request_profile(
self,
expected_code: int,
url_suffix: str = "",
access_token: str | None = None,
) -> None:
channel = self.make_request(
channel = await self.make_request(
"GET", self.profile_url + url_suffix, access_token=access_token
)
self.assertEqual(channel.code, expected_code, channel.result)
def ensure_requester_left_room(self) -> None:
async def ensure_requester_left_room(self) -> None:
try:
self.helper.leave(
await self.helper.leave(
room=self.room_id, user=self.requester, tok=self.requester_tok
)
except AssertionError:
@@ -888,36 +907,36 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
profile.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["require_auth_for_profile_requests"] = True
config["limit_profile_requests_to_users_who_share_rooms"] = True
self.hs = self.setup_test_homeserver(config=config)
self.hs = await self.setup_test_homeserver(config=config)
return self.hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# User requesting the profile.
self.requester = self.register_user("requester", "pass")
self.requester_tok = self.login("requester", "pass")
self.requester = await self.register_user("requester", "pass")
self.requester_tok = await self.login("requester", "pass")
def test_can_lookup_own_profile(self) -> None:
async def test_can_lookup_own_profile(self) -> None:
"""Tests that a user can lookup their own profile without having to be in a room
if 'require_auth_for_profile_requests' is set to true in the server's config.
"""
channel = self.make_request(
channel = await self.make_request(
"GET", "/profile/" + self.requester, access_token=self.requester_tok
)
self.assertEqual(channel.code, 200, channel.result)
channel = self.make_request(
channel = await self.make_request(
"GET",
"/profile/" + self.requester + "/displayname",
access_token=self.requester_tok,
)
self.assertEqual(channel.code, 200, channel.result)
channel = self.make_request(
channel = await self.make_request(
"GET",
"/profile/" + self.requester + "/avatar_url",
access_token=self.requester_tok,
+25 -25
View File
@@ -43,7 +43,7 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase):
admin.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
# merge this default retention config with anything that was specified in
@@ -56,27 +56,27 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase):
retention_config.update(config.get("retention", {}))
config["retention"] = retention_config
self.hs = self.setup_test_homeserver(config=config)
self.hs = await self.setup_test_homeserver(config=config)
return self.hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.owner = self.register_user("owner", "pass")
self.owner_tok = self.login("owner", "pass")
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.owner = await self.register_user("owner", "pass")
self.owner_tok = await self.login("owner", "pass")
self.store = self.hs.get_datastores().main
self.clock = self.hs.get_clock()
def test_send_read_marker(self) -> None:
room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok)
async def test_send_read_marker(self) -> None:
room_id = await self.helper.create_room_as(self.owner, tok=self.owner_tok)
def send_message() -> str:
res = self.helper.send(room_id=room_id, body="1", tok=self.owner_tok)
async def send_message() -> str:
res = await self.helper.send(room_id=room_id, body="1", tok=self.owner_tok)
return res["event_id"]
# Test setting the read marker on the room
event_id_1 = send_message()
event_id_1 = await send_message()
channel = self.make_request(
channel = await self.make_request(
"POST",
f"/rooms/{room_id}/read_markers",
content={
@@ -87,8 +87,8 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
# Test moving the read marker to a newer event
event_id_2 = send_message()
channel = self.make_request(
event_id_2 = await send_message()
channel = await self.make_request(
"POST",
f"/rooms/{room_id}/read_markers",
content={
@@ -98,30 +98,30 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
def test_send_read_marker_missing_previous_event(self) -> None:
async def test_send_read_marker_missing_previous_event(self) -> None:
"""
Test moving a read marker from an event that previously existed but was
later removed due to retention rules.
"""
room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok)
room_id = await self.helper.create_room_as(self.owner, tok=self.owner_tok)
# Set retention rule on the room so we remove old events to test this case
self.helper.send_state(
await self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
body={"max_lifetime": ONE_DAY_MS},
tok=self.owner_tok,
)
def send_message() -> str:
res = self.helper.send(room_id=room_id, body="1", tok=self.owner_tok)
async def send_message() -> str:
res = await self.helper.send(room_id=room_id, body="1", tok=self.owner_tok)
return res["event_id"]
# Test setting the read marker on the room
event_id_1 = send_message()
event_id_1 = await send_message()
channel = self.make_request(
channel = await self.make_request(
"POST",
f"/rooms/{room_id}/read_markers",
content={
@@ -131,16 +131,16 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase):
)
# Send a second message (retention will not remove the latest event ever)
send_message()
await send_message()
# And then advance so retention rules remove the first event (where the marker is)
self.reactor.advance(ONE_DAY_MS * 2 / 1000)
await self.pump(ONE_DAY_MS * 2 / 1000)
event = self.get_success(self.store.get_event(event_id_1, allow_none=True))
event = await self.get_success(self.store.get_event(event_id_1, allow_none=True))
assert event is None
# Test moving the read marker to a newer event
event_id_2 = send_message()
channel = self.make_request(
event_id_2 = await send_message()
channel = await self.make_request(
"POST",
f"/rooms/{room_id}/read_markers",
content={
+14 -14
View File
@@ -42,18 +42,18 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
user = UserID.from_string(user_id)
servlets = [room.register_servlets]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("red")
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = await self.setup_test_homeserver("red")
self.event_source = hs.get_event_sources().sources.typing
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id)
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = await self.helper.create_room_as(self.user_id)
# Need another user to make notifications actually work
self.helper.join(self.room_id, user="@jim:red")
await self.helper.join(self.room_id, user="@jim:red")
def test_set_typing(self) -> None:
channel = self.make_request(
async def test_set_typing(self) -> None:
channel = await self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": true, "timeout": 30000}',
@@ -61,7 +61,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code)
self.assertEqual(self.event_source.get_current_key(), 1)
events = self.get_success(
events = await self.get_success(
self.event_source.get_new_events(
user=UserID.from_string(self.user_id),
from_key=0,
@@ -82,16 +82,16 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
],
)
def test_set_not_typing(self) -> None:
channel = self.make_request(
async def test_set_not_typing(self) -> None:
channel = await self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": false}',
)
self.assertEqual(200, channel.code)
def test_typing_timeout(self) -> None:
channel = self.make_request(
async def test_typing_timeout(self) -> None:
channel = await self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": true, "timeout": 30000}',
@@ -100,11 +100,11 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.assertEqual(self.event_source.get_current_key(), 1)
self.reactor.advance(36)
await self.pump(36)
self.assertEqual(self.event_source.get_current_key(), 2)
channel = self.make_request(
channel = await self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": true, "timeout": 30000}',
+16 -11
View File
@@ -202,13 +202,11 @@ class TestCase(_stdlib_unittest.IsolatedAsyncioTestCase):
"""
def setUp(self) -> None:
# if we're not starting in the sentinel logcontext, then to be honest
# all future bets are off.
# If we're not starting in the sentinel logcontext, reset it.
# In asyncio, cancelled background tasks may leave a non-sentinel
# context during event loop shutdown between tests.
if current_context():
self.fail(
"Test starting with non-sentinel logging context %s"
% (current_context(),)
)
set_current_context(SENTINEL_CONTEXT)
# Disable GC for duration of test (re-enabled in tearDown).
gc.disable()
@@ -585,9 +583,11 @@ class HomeserverTestCase(TestCase):
"""
if hasattr(self, 'hs') and self.hs is not None:
self.hs.get_clock().shutdown()
# Give event loop time for cancellation to propagate
await asyncio.sleep(0)
await asyncio.sleep(0)
# Give event loop time for cancellation callbacks to propagate.
# Multiple yields are needed because cancellation of a task may
# trigger cleanup in other tasks (e.g. context managers exiting).
for _ in range(5):
await asyncio.sleep(0)
# Ensure we're back in the sentinel context
set_current_context(SENTINEL_CONTEXT)
@@ -819,11 +819,16 @@ class HomeserverTestCase(TestCase):
``reactor.advance()`` delegates to ``clock.advance()``, so calling
either one advances the same fake-time source.
After advancing, yields multiple times to allow background tasks
(including those running in executor threads) to complete.
"""
# Advance fake time (fires pending sleeps and callFromThread)
self.reactor.advance(by)
# Yield to the event loop so callbacks can run
await asyncio.sleep(0)
# Yield to the event loop multiple times so background tasks
# (including DB operations in executor threads) can complete.
for _ in range(20):
await asyncio.sleep(0.01)
async def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV:
"""Await an awaitable, optionally advancing fake time first."""