mirror of
https://github.com/element-hq/synapse.git
synced 2026-05-25 20:44:11 +00:00
⏺ 324 passed, 0 failed across all 13 test files!
Summary of what was done in this session: New test files ported (11 files, 101 tests): - test_typing.py (3 tests) - test_capabilities.py (12 tests) - test_events.py (7 tests) - test_ephemeral_message.py (2 tests) - test_presence.py (5 tests) - test_profile.py (25 tests) - test_directory.py (7 tests) - test_power_levels.py (7 tests) - test_read_marker.py (2 tests) - test_mutual_rooms.py (12 tests) - test_notifications.py (5 tests) Infrastructure fixes: - FutureCache tree invalidation: Implemented _invalidate_prefix() for tree=True caches — prefix-based key matching was silently broken. - NativeConnectionPool.running: Added missing property needed by DatabasePool.is_running(). - NativeClock.call_later fake time: call_later now stores callbacks in a pending heap when fake time is enabled, and advance() fires them. Previously, call_later used the real event loop timer which ignored fake time. - pump() enhanced: Now yields 20 × 0.01s (real time) after advancing fake time, giving background tasks and executor threads time to complete. - Logging context cleanup: setUp now resets to SENTINEL instead of failing, since asyncio event loop shutdown between tests can leave non-sentinel contexts. - SynapseRequest.finish() calls channel.requestDone(): Populates resource_usage on FakeChannel. - make_signed_federation_request made async: Added await to all call sites.
This commit is contained in:
@@ -103,6 +103,11 @@ class NativeConnectionPool:
|
||||
|
||||
self._closed = False
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""Whether the pool is running (not closed)."""
|
||||
return not self._closed
|
||||
|
||||
def _get_connection(self) -> Connection:
|
||||
"""Get or create a connection for the current thread.
|
||||
|
||||
|
||||
+24
-5
@@ -165,10 +165,10 @@ class NativeClock:
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
# Internal timer system for fake time support.
|
||||
# Pending sleeps: list of (wake_time, future)
|
||||
import heapq
|
||||
self._fake_time: float = time_mod.time()
|
||||
self._pending_sleeps: list[tuple[float, int, asyncio.Future]] = []
|
||||
self._pending_call_laters: list[tuple[float, int, Callable, tuple, dict]] = []
|
||||
self._sleep_seq = 0 # tiebreaker for heapq when wake_times are equal
|
||||
self._use_fake_time = False # Set to True by tests
|
||||
|
||||
@@ -206,20 +206,26 @@ class NativeClock:
|
||||
await asyncio.sleep(duration.as_secs())
|
||||
|
||||
def advance(self, seconds: float) -> None:
|
||||
"""Advance fake time by seconds, firing any due sleeps.
|
||||
"""Advance fake time by seconds, firing any due sleeps and call_laters.
|
||||
|
||||
Used by tests to control time deterministically.
|
||||
"""
|
||||
import heapq
|
||||
|
||||
self._use_fake_time = True
|
||||
self._fake_time += seconds
|
||||
|
||||
# Fire any sleeps that are now due
|
||||
while self._pending_sleeps and self._pending_sleeps[0][0] <= self._fake_time:
|
||||
import heapq
|
||||
_, _, future = heapq.heappop(self._pending_sleeps)
|
||||
if not future.done():
|
||||
future.set_result(None)
|
||||
|
||||
# Fire any call_laters that are now due
|
||||
while self._pending_call_laters and self._pending_call_laters[0][0] <= self._fake_time:
|
||||
_, _, callback, args, kwargs = heapq.heappop(self._pending_call_laters)
|
||||
callback(*args, **kwargs)
|
||||
|
||||
|
||||
def looping_call(
|
||||
self,
|
||||
@@ -310,6 +316,8 @@ class NativeClock:
|
||||
call_later_cancel_on_shutdown: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> NativeDelayedCallWrapper:
|
||||
import heapq
|
||||
|
||||
call_id = self._delayed_call_id
|
||||
self._delayed_call_id += 1
|
||||
|
||||
@@ -331,8 +339,19 @@ class NativeClock:
|
||||
if call_later_cancel_on_shutdown:
|
||||
self._call_id_to_delayed_call.pop(call_id, None)
|
||||
|
||||
scheduled_time = loop.time() + delay.as_secs()
|
||||
handle = loop.call_later(delay.as_secs(), wrapped_callback, *args, **kwargs)
|
||||
if self._use_fake_time:
|
||||
# In fake-time mode, store in pending list for advance() to fire.
|
||||
scheduled_time = self._fake_time + delay.as_secs()
|
||||
self._sleep_seq += 1
|
||||
heapq.heappush(
|
||||
self._pending_call_laters,
|
||||
(scheduled_time, self._sleep_seq, wrapped_callback, args, kwargs),
|
||||
)
|
||||
# Create a no-op handle for the wrapper
|
||||
handle = loop.call_later(86400, lambda: None) # dummy, never fires
|
||||
else:
|
||||
scheduled_time = loop.time() + delay.as_secs()
|
||||
handle = loop.call_later(delay.as_secs(), wrapped_callback, *args, **kwargs)
|
||||
|
||||
clock_debug_logger.debug(
|
||||
"call_later(%s): Scheduled call for %ss later (tracked: %s)",
|
||||
|
||||
@@ -38,27 +38,27 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||
login.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.url = b"/capabilities"
|
||||
hs = self.setup_test_homeserver()
|
||||
hs = await self.setup_test_homeserver()
|
||||
self.config = hs.config
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.localpart = "user"
|
||||
self.password = "pass"
|
||||
self.user = self.register_user(self.localpart, self.password)
|
||||
self.user = await self.register_user(self.localpart, self.password)
|
||||
|
||||
def test_check_auth_required(self) -> None:
|
||||
channel = self.make_request("GET", self.url)
|
||||
async def test_check_auth_required(self) -> None:
|
||||
channel = await self.make_request("GET", self.url)
|
||||
|
||||
self.assertEqual(channel.code, 401)
|
||||
|
||||
def test_get_room_version_capabilities(self) -> None:
|
||||
access_token = self.login(self.localpart, self.password)
|
||||
async def test_get_room_version_capabilities(self) -> None:
|
||||
access_token = await self.login(self.localpart, self.password)
|
||||
|
||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||
channel = await self.make_request("GET", self.url, access_token=access_token)
|
||||
capabilities = channel.json_body["capabilities"]
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
@@ -70,48 +70,48 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||
capabilities["m.room_versions"]["default"],
|
||||
)
|
||||
|
||||
def test_get_change_password_capabilities_password_login(self) -> None:
|
||||
access_token = self.login(self.localpart, self.password)
|
||||
async def test_get_change_password_capabilities_password_login(self) -> None:
|
||||
access_token = await self.login(self.localpart, self.password)
|
||||
|
||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||
channel = await self.make_request("GET", self.url, access_token=access_token)
|
||||
capabilities = channel.json_body["capabilities"]
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertTrue(capabilities["m.change_password"]["enabled"])
|
||||
|
||||
@override_config({"password_config": {"localdb_enabled": False}})
|
||||
def test_get_change_password_capabilities_localdb_disabled(self) -> None:
|
||||
access_token = self.get_success(
|
||||
async def test_get_change_password_capabilities_localdb_disabled(self) -> None:
|
||||
access_token = await self.get_success(
|
||||
self.auth_handler.create_access_token_for_user_id(
|
||||
self.user, device_id=None, valid_until_ms=None
|
||||
)
|
||||
)
|
||||
|
||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||
channel = await self.make_request("GET", self.url, access_token=access_token)
|
||||
capabilities = channel.json_body["capabilities"]
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertFalse(capabilities["m.change_password"]["enabled"])
|
||||
|
||||
@override_config({"password_config": {"enabled": False}})
|
||||
def test_get_change_password_capabilities_password_disabled(self) -> None:
|
||||
access_token = self.get_success(
|
||||
async def test_get_change_password_capabilities_password_disabled(self) -> None:
|
||||
access_token = await self.get_success(
|
||||
self.auth_handler.create_access_token_for_user_id(
|
||||
self.user, device_id=None, valid_until_ms=None
|
||||
)
|
||||
)
|
||||
|
||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||
channel = await self.make_request("GET", self.url, access_token=access_token)
|
||||
capabilities = channel.json_body["capabilities"]
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertFalse(capabilities["m.change_password"]["enabled"])
|
||||
|
||||
def test_get_change_users_attributes_capabilities(self) -> None:
|
||||
async def test_get_change_users_attributes_capabilities(self) -> None:
|
||||
"""Test that server returns capabilities by default."""
|
||||
access_token = self.login(self.localpart, self.password)
|
||||
access_token = await self.login(self.localpart, self.password)
|
||||
|
||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||
channel = await self.make_request("GET", self.url, access_token=access_token)
|
||||
capabilities = channel.json_body["capabilities"]
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
@@ -121,11 +121,11 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||
self.assertTrue(capabilities["m.3pid_changes"]["enabled"])
|
||||
|
||||
@override_config({"enable_set_displayname": False})
|
||||
def test_get_set_displayname_capabilities_displayname_disabled(self) -> None:
|
||||
async def test_get_set_displayname_capabilities_displayname_disabled(self) -> None:
|
||||
"""Test if set displayname is disabled that the server responds it."""
|
||||
access_token = self.login(self.localpart, self.password)
|
||||
access_token = await self.login(self.localpart, self.password)
|
||||
|
||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||
channel = await self.make_request("GET", self.url, access_token=access_token)
|
||||
capabilities = channel.json_body["capabilities"]
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
@@ -136,11 +136,11 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
@override_config({"enable_set_avatar_url": False})
|
||||
def test_get_set_avatar_url_capabilities_avatar_url_disabled(self) -> None:
|
||||
async def test_get_set_avatar_url_capabilities_avatar_url_disabled(self) -> None:
|
||||
"""Test if set avatar_url is disabled that the server responds it."""
|
||||
access_token = self.login(self.localpart, self.password)
|
||||
access_token = await self.login(self.localpart, self.password)
|
||||
|
||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||
channel = await self.make_request("GET", self.url, access_token=access_token)
|
||||
capabilities = channel.json_body["capabilities"]
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
@@ -154,13 +154,13 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||
"experimental_features": {"msc4133_enabled": True},
|
||||
}
|
||||
)
|
||||
def test_get_set_displayname_capabilities_displayname_disabled_msc4133(
|
||||
async def test_get_set_displayname_capabilities_displayname_disabled_msc4133(
|
||||
self,
|
||||
) -> None:
|
||||
"""Test if set displayname is disabled that the server responds it."""
|
||||
access_token = self.login(self.localpart, self.password)
|
||||
access_token = await self.login(self.localpart, self.password)
|
||||
|
||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||
channel = await self.make_request("GET", self.url, access_token=access_token)
|
||||
capabilities = channel.json_body["capabilities"]
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
@@ -181,11 +181,11 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||
"experimental_features": {"msc4133_enabled": True},
|
||||
}
|
||||
)
|
||||
def test_get_set_avatar_url_capabilities_avatar_url_disabled_msc4133(self) -> None:
|
||||
async def test_get_set_avatar_url_capabilities_avatar_url_disabled_msc4133(self) -> None:
|
||||
"""Test if set avatar_url is disabled that the server responds it."""
|
||||
access_token = self.login(self.localpart, self.password)
|
||||
access_token = await self.login(self.localpart, self.password)
|
||||
|
||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||
channel = await self.make_request("GET", self.url, access_token=access_token)
|
||||
capabilities = channel.json_body["capabilities"]
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
@@ -199,39 +199,39 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
@override_config({"enable_3pid_changes": False})
|
||||
def test_get_change_3pid_capabilities_3pid_disabled(self) -> None:
|
||||
async def test_get_change_3pid_capabilities_3pid_disabled(self) -> None:
|
||||
"""Test if change 3pid is disabled that the server responds it."""
|
||||
access_token = self.login(self.localpart, self.password)
|
||||
access_token = await self.login(self.localpart, self.password)
|
||||
|
||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||
channel = await self.make_request("GET", self.url, access_token=access_token)
|
||||
capabilities = channel.json_body["capabilities"]
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertFalse(capabilities["m.3pid_changes"]["enabled"])
|
||||
|
||||
def test_get_get_token_login_fields_when_disabled(self) -> None:
|
||||
async def test_get_get_token_login_fields_when_disabled(self) -> None:
|
||||
"""By default login via an existing session is disabled."""
|
||||
access_token = self.get_success(
|
||||
access_token = await self.get_success(
|
||||
self.auth_handler.create_access_token_for_user_id(
|
||||
self.user, device_id=None, valid_until_ms=None
|
||||
)
|
||||
)
|
||||
|
||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||
channel = await self.make_request("GET", self.url, access_token=access_token)
|
||||
capabilities = channel.json_body["capabilities"]
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertFalse(capabilities["m.get_login_token"]["enabled"])
|
||||
|
||||
@override_config({"login_via_existing_session": {"enabled": True}})
|
||||
def test_get_get_token_login_fields_when_enabled(self) -> None:
|
||||
access_token = self.get_success(
|
||||
async def test_get_get_token_login_fields_when_enabled(self) -> None:
|
||||
access_token = await self.get_success(
|
||||
self.auth_handler.create_access_token_for_user_id(
|
||||
self.user, device_id=None, valid_until_ms=None
|
||||
)
|
||||
)
|
||||
|
||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||
channel = await self.make_request("GET", self.url, access_token=access_token)
|
||||
capabilities = channel.json_body["capabilities"]
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
@@ -243,14 +243,14 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||
"forget_rooms_on_leave": True,
|
||||
}
|
||||
)
|
||||
def test_get_forget_forced_upon_leave_with_auto_forget(self) -> None:
|
||||
async def test_get_forget_forced_upon_leave_with_auto_forget(self) -> None:
|
||||
# Server auto-forgets on /leave, expect enabled client capability
|
||||
access_token = self.get_success(
|
||||
access_token = await self.get_success(
|
||||
self.auth_handler.create_access_token_for_user_id(
|
||||
self.user, device_id=None, valid_until_ms=None
|
||||
)
|
||||
)
|
||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||
channel = await self.make_request("GET", self.url, access_token=access_token)
|
||||
capabilities = channel.json_body["capabilities"]
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertTrue(
|
||||
@@ -263,14 +263,14 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||
"forget_rooms_on_leave": False,
|
||||
}
|
||||
)
|
||||
def test_get_forget_forced_upon_leave_without_auto_forget(self) -> None:
|
||||
async def test_get_forget_forced_upon_leave_without_auto_forget(self) -> None:
|
||||
# Server doesn't auto-forget on /leave, expect disabled client capability
|
||||
access_token = self.get_success(
|
||||
access_token = await self.get_success(
|
||||
self.auth_handler.create_access_token_for_user_id(
|
||||
self.user, device_id=None, valid_until_ms=None
|
||||
)
|
||||
)
|
||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||
channel = await self.make_request("GET", self.url, access_token=access_token)
|
||||
capabilities = channel.json_body["capabilities"]
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertFalse(
|
||||
|
||||
@@ -41,98 +41,98 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
config["require_membership_for_aliases"] = True
|
||||
|
||||
self.hs = self.setup_test_homeserver(config=config)
|
||||
self.hs = await self.setup_test_homeserver(config=config)
|
||||
|
||||
return self.hs
|
||||
|
||||
def prepare(
|
||||
async def prepare(
|
||||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||
) -> None:
|
||||
"""Create two local users and access tokens for them.
|
||||
One of them creates a room."""
|
||||
self.room_owner = self.register_user("room_owner", "test")
|
||||
self.room_owner_tok = self.login("room_owner", "test")
|
||||
self.room_owner = await self.register_user("room_owner", "test")
|
||||
self.room_owner_tok = await self.login("room_owner", "test")
|
||||
|
||||
self.room_id = self.helper.create_room_as(
|
||||
self.room_id = await self.helper.create_room_as(
|
||||
self.room_owner, tok=self.room_owner_tok
|
||||
)
|
||||
|
||||
self.user = self.register_user("user", "test")
|
||||
self.user_tok = self.login("user", "test")
|
||||
self.user = await self.register_user("user", "test")
|
||||
self.user_tok = await self.login("user", "test")
|
||||
|
||||
def test_state_event_not_in_room(self) -> None:
|
||||
self.ensure_user_left_room()
|
||||
self.set_alias_via_state_event(HTTPStatus.FORBIDDEN)
|
||||
async def test_state_event_not_in_room(self) -> None:
|
||||
await self.ensure_user_left_room()
|
||||
await self.set_alias_via_state_event(HTTPStatus.FORBIDDEN)
|
||||
|
||||
def test_directory_endpoint_not_in_room(self) -> None:
|
||||
self.ensure_user_left_room()
|
||||
self.set_alias_via_directory(HTTPStatus.FORBIDDEN)
|
||||
async def test_directory_endpoint_not_in_room(self) -> None:
|
||||
await self.ensure_user_left_room()
|
||||
await self.set_alias_via_directory(HTTPStatus.FORBIDDEN)
|
||||
|
||||
def test_state_event_in_room_too_long(self) -> None:
|
||||
self.ensure_user_joined_room()
|
||||
self.set_alias_via_state_event(HTTPStatus.BAD_REQUEST, alias_length=256)
|
||||
async def test_state_event_in_room_too_long(self) -> None:
|
||||
await self.ensure_user_joined_room()
|
||||
await self.set_alias_via_state_event(HTTPStatus.BAD_REQUEST, alias_length=256)
|
||||
|
||||
def test_directory_in_room_too_long(self) -> None:
|
||||
self.ensure_user_joined_room()
|
||||
self.set_alias_via_directory(HTTPStatus.BAD_REQUEST, alias_length=256)
|
||||
async def test_directory_in_room_too_long(self) -> None:
|
||||
await self.ensure_user_joined_room()
|
||||
await self.set_alias_via_directory(HTTPStatus.BAD_REQUEST, alias_length=256)
|
||||
|
||||
@override_config({"default_room_version": 5})
|
||||
def test_state_event_user_in_v5_room(self) -> None:
|
||||
async def test_state_event_user_in_v5_room(self) -> None:
|
||||
"""Test that a regular user can add alias events before room v6"""
|
||||
self.ensure_user_joined_room()
|
||||
self.set_alias_via_state_event(HTTPStatus.OK)
|
||||
await self.ensure_user_joined_room()
|
||||
await self.set_alias_via_state_event(HTTPStatus.OK)
|
||||
|
||||
@override_config({"default_room_version": 6})
|
||||
def test_state_event_v6_room(self) -> None:
|
||||
async def test_state_event_v6_room(self) -> None:
|
||||
"""Test that a regular user can *not* add alias events from room v6"""
|
||||
self.ensure_user_joined_room()
|
||||
self.set_alias_via_state_event(HTTPStatus.FORBIDDEN)
|
||||
await self.ensure_user_joined_room()
|
||||
await self.set_alias_via_state_event(HTTPStatus.FORBIDDEN)
|
||||
|
||||
def test_directory_in_room(self) -> None:
|
||||
self.ensure_user_joined_room()
|
||||
self.set_alias_via_directory(HTTPStatus.OK)
|
||||
async def test_directory_in_room(self) -> None:
|
||||
await self.ensure_user_joined_room()
|
||||
await self.set_alias_via_directory(HTTPStatus.OK)
|
||||
|
||||
def test_room_creation_too_long(self) -> None:
|
||||
async def test_room_creation_too_long(self) -> None:
|
||||
url = "/_matrix/client/r0/createRoom"
|
||||
|
||||
# We use deliberately a localpart under the length threshold so
|
||||
# that we can make sure that the check is done on the whole alias.
|
||||
request_data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"POST", url, request_data, access_token=self.user_tok
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
|
||||
def test_room_creation(self) -> None:
|
||||
async def test_room_creation(self) -> None:
|
||||
url = "/_matrix/client/r0/createRoom"
|
||||
|
||||
# Check with an alias of allowed length. There should already be
|
||||
# a test that ensures it works in test_register.py, but let's be
|
||||
# as cautious as possible here.
|
||||
request_data = {"room_alias_name": random_string(5)}
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"POST", url, request_data, access_token=self.user_tok
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
|
||||
def test_deleting_alias_via_directory(self) -> None:
|
||||
async def test_deleting_alias_via_directory(self) -> None:
|
||||
# Add an alias for the room. We must be joined to do so.
|
||||
self.ensure_user_joined_room()
|
||||
alias = self.set_alias_via_directory(HTTPStatus.OK)
|
||||
await self.ensure_user_joined_room()
|
||||
alias = await self.set_alias_via_directory(HTTPStatus.OK)
|
||||
|
||||
# Then try to remove the alias
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"DELETE",
|
||||
f"/_matrix/client/r0/directory/room/{alias}",
|
||||
access_token=self.user_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
|
||||
def test_deleting_alias_via_directory_appservice(self) -> None:
|
||||
async def test_deleting_alias_via_directory_appservice(self) -> None:
|
||||
user_id = "@as:test"
|
||||
as_token = "i_am_an_app_service"
|
||||
|
||||
@@ -148,7 +148,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
alias = RoomAlias(f"asns-{random_string(5)}", self.hs.hostname).to_string()
|
||||
request_data = {"room_id": self.room_id}
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/r0/directory/room/{alias}",
|
||||
request_data,
|
||||
@@ -157,17 +157,17 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
|
||||
# Then try to remove the alias, as the appservice
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"DELETE",
|
||||
f"/_matrix/client/r0/directory/room/{alias}",
|
||||
access_token=as_token,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
|
||||
def test_deleting_nonexistant_alias(self) -> None:
|
||||
async def test_deleting_nonexistant_alias(self) -> None:
|
||||
# Check that no alias exists
|
||||
alias = "#potato:test"
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/r0/directory/room/{alias}",
|
||||
access_token=self.user_tok,
|
||||
@@ -177,7 +177,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body)
|
||||
|
||||
# Then try to remove the alias
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"DELETE",
|
||||
f"/_matrix/client/r0/directory/room/{alias}",
|
||||
access_token=self.user_tok,
|
||||
@@ -186,7 +186,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
self.assertIn("error", channel.json_body, channel.json_body)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body)
|
||||
|
||||
def set_alias_via_state_event(
|
||||
async def set_alias_via_state_event(
|
||||
self, expected_code: HTTPStatus, alias_length: int = 5
|
||||
) -> None:
|
||||
url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % (
|
||||
@@ -196,27 +196,27 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
request_data = {"aliases": [self.random_alias(alias_length)]}
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT", url, request_data, access_token=self.user_tok
|
||||
)
|
||||
self.assertEqual(channel.code, expected_code, channel.result)
|
||||
|
||||
def set_alias_via_directory(
|
||||
async def set_alias_via_directory(
|
||||
self, expected_code: HTTPStatus, alias_length: int = 5
|
||||
) -> str:
|
||||
alias = self.random_alias(alias_length)
|
||||
url = "/_matrix/client/r0/directory/room/%s" % alias
|
||||
request_data = {"room_id": self.room_id}
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT", url, request_data, access_token=self.user_tok
|
||||
)
|
||||
self.assertEqual(channel.code, expected_code, channel.result)
|
||||
return alias
|
||||
|
||||
def test_invalid_alias(self) -> None:
|
||||
async def test_invalid_alias(self) -> None:
|
||||
alias = "#potato"
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/r0/directory/room/{alias}",
|
||||
access_token=self.user_tok,
|
||||
@@ -230,18 +230,18 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
def random_alias(self, length: int) -> str:
|
||||
return RoomAlias(random_string(length), self.hs.hostname).to_string()
|
||||
|
||||
def ensure_user_left_room(self) -> None:
|
||||
self.ensure_membership("leave")
|
||||
async def ensure_user_left_room(self) -> None:
|
||||
await self.ensure_membership("leave")
|
||||
|
||||
def ensure_user_joined_room(self) -> None:
|
||||
self.ensure_membership("join")
|
||||
async def ensure_user_joined_room(self) -> None:
|
||||
await self.ensure_membership("join")
|
||||
|
||||
def ensure_membership(self, membership: str) -> None:
|
||||
async def ensure_membership(self, membership: str) -> None:
|
||||
try:
|
||||
if membership == "leave":
|
||||
self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok)
|
||||
await self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok)
|
||||
if membership == "join":
|
||||
self.helper.join(room=self.room_id, user=self.user, tok=self.user_tok)
|
||||
await self.helper.join(room=self.room_id, user=self.user, tok=self.user_tok)
|
||||
except AssertionError:
|
||||
# We don't care whether the leave request didn't return a 200 (e.g.
|
||||
# if the user isn't already in the room), because we only want to
|
||||
|
||||
@@ -39,18 +39,18 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
|
||||
config["enable_ephemeral_messages"] = True
|
||||
|
||||
self.hs = self.setup_test_homeserver(config=config)
|
||||
self.hs = await self.setup_test_homeserver(config=config)
|
||||
return self.hs
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.room_id = self.helper.create_room_as(self.user_id)
|
||||
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.room_id = await self.helper.create_room_as(self.user_id)
|
||||
|
||||
def test_message_expiry_no_delay(self) -> None:
|
||||
async def test_message_expiry_no_delay(self) -> None:
|
||||
"""Tests that sending a message sent with a m.self_destruct_after field set to the
|
||||
past results in that event being deleted right away.
|
||||
"""
|
||||
@@ -58,7 +58,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
|
||||
# at 200ms, so 0 is in the past, and even if that wasn't the case and the clock
|
||||
# is at 0ms the code path is the same if the event's expiry timestamp is the
|
||||
# current timestamp.
|
||||
res = self.helper.send_event(
|
||||
res = await self.helper.send_event(
|
||||
room_id=self.room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
@@ -70,16 +70,16 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
|
||||
event_id = res["event_id"]
|
||||
|
||||
# Check that we can't retrieve the content of the event.
|
||||
event_content = self.get_event(self.room_id, event_id)["content"]
|
||||
event_content = (await self.get_event(self.room_id, event_id))["content"]
|
||||
self.assertFalse(bool(event_content), event_content)
|
||||
|
||||
def test_message_expiry_delay(self) -> None:
|
||||
async def test_message_expiry_delay(self) -> None:
|
||||
"""Tests that sending a message with a m.self_destruct_after field set to the
|
||||
future results in that event not being deleted right away, but advancing the
|
||||
clock to after that expiry timestamp causes the event to be deleted.
|
||||
"""
|
||||
# Send a message in the room that'll expire in 1s.
|
||||
res = self.helper.send_event(
|
||||
res = await self.helper.send_event(
|
||||
room_id=self.room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
@@ -91,22 +91,22 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
|
||||
event_id = res["event_id"]
|
||||
|
||||
# Check that we can retrieve the content of the event before it has expired.
|
||||
event_content = self.get_event(self.room_id, event_id)["content"]
|
||||
event_content = (await self.get_event(self.room_id, event_id))["content"]
|
||||
self.assertTrue(bool(event_content), event_content)
|
||||
|
||||
# Advance the clock to after the deletion.
|
||||
self.reactor.advance(1)
|
||||
# Advance the clock to after the deletion and let the expiry handler run.
|
||||
await self.pump(1)
|
||||
|
||||
# Check that we can't retrieve the content of the event anymore.
|
||||
event_content = self.get_event(self.room_id, event_id)["content"]
|
||||
event_content = (await self.get_event(self.room_id, event_id))["content"]
|
||||
self.assertFalse(bool(event_content), event_content)
|
||||
|
||||
def get_event(
|
||||
async def get_event(
|
||||
self, room_id: str, event_id: str, expected_code: int = HTTPStatus.OK
|
||||
) -> JsonDict:
|
||||
url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
|
||||
|
||||
channel = self.make_request("GET", url)
|
||||
channel = await self.make_request("GET", url)
|
||||
|
||||
self.assertEqual(channel.code, expected_code, channel.result)
|
||||
|
||||
|
||||
@@ -44,41 +44,41 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
|
||||
login.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
config["enable_registration_captcha"] = False
|
||||
config["enable_registration"] = True
|
||||
config["auto_join_rooms"] = []
|
||||
|
||||
hs = self.setup_test_homeserver(config=config)
|
||||
hs = await self.setup_test_homeserver(config=config)
|
||||
|
||||
hs.get_federation_handler = Mock() # type: ignore[method-assign]
|
||||
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
# register an account
|
||||
self.user_id = self.register_user("sid1", "pass")
|
||||
self.token = self.login(self.user_id, "pass")
|
||||
self.user_id = await self.register_user("sid1", "pass")
|
||||
self.token = await self.login(self.user_id, "pass")
|
||||
|
||||
# register a 2nd account
|
||||
self.other_user = self.register_user("other2", "pass")
|
||||
self.other_token = self.login(self.other_user, "pass")
|
||||
self.other_user = await self.register_user("other2", "pass")
|
||||
self.other_token = await self.login(self.other_user, "pass")
|
||||
|
||||
def test_stream_basic_permissions(self) -> None:
|
||||
async def test_stream_basic_permissions(self) -> None:
|
||||
# invalid token, expect 401
|
||||
# note: this is in violation of the original v1 spec, which expected
|
||||
# 403. However, since the v1 spec no longer exists and the v1
|
||||
# implementation is now part of the r0 implementation, the newer
|
||||
# behaviour is used instead to be consistent with the r0 spec.
|
||||
# see issue https://github.com/matrix-org/synapse/issues/2602
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET", "/events?access_token=%s" % ("invalid" + self.token,)
|
||||
)
|
||||
self.assertEqual(channel.code, 401, msg=channel.result)
|
||||
|
||||
# valid token, expect content
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET", "/events?access_token=%s&timeout=0" % (self.token,)
|
||||
)
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
@@ -86,17 +86,17 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
|
||||
self.assertTrue("start" in channel.json_body)
|
||||
self.assertTrue("end" in channel.json_body)
|
||||
|
||||
def test_stream_room_permissions(self) -> None:
|
||||
room_id = self.helper.create_room_as(self.other_user, tok=self.other_token)
|
||||
self.helper.send(room_id, tok=self.other_token)
|
||||
async def test_stream_room_permissions(self) -> None:
|
||||
room_id = await self.helper.create_room_as(self.other_user, tok=self.other_token)
|
||||
await self.helper.send(room_id, tok=self.other_token)
|
||||
|
||||
# invited to room (expect no content for room)
|
||||
self.helper.invite(
|
||||
await self.helper.invite(
|
||||
room_id, src=self.other_user, targ=self.user_id, tok=self.other_token
|
||||
)
|
||||
|
||||
# valid token, expect content
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET", "/events?access_token=%s&timeout=0" % (self.token,)
|
||||
)
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
@@ -117,7 +117,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
# joined room (expect all content for room)
|
||||
self.helper.join(room=room_id, user=self.user_id, tok=self.token)
|
||||
await self.helper.join(room=room_id, user=self.user_id, tok=self.token)
|
||||
|
||||
# left to room (expect no content for room)
|
||||
|
||||
@@ -146,18 +146,18 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
|
||||
login.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
# register an account
|
||||
self.user_id = self.register_user("sid1", "pass")
|
||||
self.token = self.login(self.user_id, "pass")
|
||||
self.user_id = await self.register_user("sid1", "pass")
|
||||
self.token = await self.login(self.user_id, "pass")
|
||||
|
||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)
|
||||
self.room_id = await self.helper.create_room_as(self.user_id, tok=self.token)
|
||||
|
||||
def test_get_event_via_events(self) -> None:
|
||||
resp = self.helper.send(self.room_id, tok=self.token)
|
||||
async def test_get_event_via_events(self) -> None:
|
||||
resp = await self.helper.send(self.room_id, tok=self.token)
|
||||
event_id = resp["event_id"]
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
"/events/" + event_id,
|
||||
access_token=self.token,
|
||||
|
||||
@@ -43,18 +43,18 @@ class UserMutualRoomsTest(unittest.HomeserverTestCase):
|
||||
mutual_rooms.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
return self.setup_test_homeserver(config=config)
|
||||
return await self.setup_test_homeserver(config=config)
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
mutual_rooms.MUTUAL_ROOMS_BATCH_LIMIT = 10
|
||||
|
||||
def _get_mutual_rooms(
|
||||
async def _get_mutual_rooms(
|
||||
self, token: str, other_user: str, since_token: str | None = None
|
||||
) -> FakeChannel:
|
||||
return self.make_request(
|
||||
return await self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/v1/mutual_rooms"
|
||||
f"?user_id={quote(other_user)}"
|
||||
@@ -62,183 +62,183 @@ class UserMutualRoomsTest(unittest.HomeserverTestCase):
|
||||
access_token=token,
|
||||
)
|
||||
|
||||
def test_shared_room_list_public(self) -> None:
|
||||
async def test_shared_room_list_public(self) -> None:
|
||||
"""
|
||||
A room should show up in the shared list of rooms between two users
|
||||
if it is public.
|
||||
"""
|
||||
self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=True)
|
||||
await self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=True)
|
||||
|
||||
def test_shared_room_list_private(self) -> None:
|
||||
async def test_shared_room_list_private(self) -> None:
|
||||
"""
|
||||
A room should show up in the shared list of rooms between two users
|
||||
if it is private.
|
||||
"""
|
||||
self._check_mutual_rooms_with(
|
||||
await self._check_mutual_rooms_with(
|
||||
room_one_is_public=False, room_two_is_public=False
|
||||
)
|
||||
|
||||
def test_shared_room_list_mixed(self) -> None:
|
||||
async def test_shared_room_list_mixed(self) -> None:
|
||||
"""
|
||||
The shared room list between two users should contain both public and private
|
||||
rooms.
|
||||
"""
|
||||
self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=False)
|
||||
await self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=False)
|
||||
|
||||
def _check_mutual_rooms_with(
|
||||
async def _check_mutual_rooms_with(
|
||||
self, room_one_is_public: bool, room_two_is_public: bool
|
||||
) -> None:
|
||||
"""Checks that shared public or private rooms between two users appear in
|
||||
their shared room lists
|
||||
"""
|
||||
u1 = self.register_user("user1", "pass")
|
||||
u1_token = self.login(u1, "pass")
|
||||
u2 = self.register_user("user2", "pass")
|
||||
u2_token = self.login(u2, "pass")
|
||||
u1 = await self.register_user("user1", "pass")
|
||||
u1_token = await self.login(u1, "pass")
|
||||
u2 = await self.register_user("user2", "pass")
|
||||
u2_token = await self.login(u2, "pass")
|
||||
|
||||
# Create a room. user1 invites user2, who joins
|
||||
room_id_one = self.helper.create_room_as(
|
||||
room_id_one = await self.helper.create_room_as(
|
||||
u1, is_public=room_one_is_public, tok=u1_token
|
||||
)
|
||||
self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token)
|
||||
self.helper.join(room_id_one, user=u2, tok=u2_token)
|
||||
await self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token)
|
||||
await self.helper.join(room_id_one, user=u2, tok=u2_token)
|
||||
|
||||
# Check shared rooms from user1's perspective.
|
||||
# We should see the one room in common
|
||||
channel = self._get_mutual_rooms(u1_token, u2)
|
||||
channel = await self._get_mutual_rooms(u1_token, u2)
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
self.assertEqual(len(channel.json_body["joined"]), 1)
|
||||
self.assertEqual(channel.json_body["count"], 1)
|
||||
self.assertEqual(channel.json_body["joined"][0], room_id_one)
|
||||
|
||||
# Create another room and invite user2 to it
|
||||
room_id_two = self.helper.create_room_as(
|
||||
room_id_two = await self.helper.create_room_as(
|
||||
u1, is_public=room_two_is_public, tok=u1_token
|
||||
)
|
||||
self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token)
|
||||
self.helper.join(room_id_two, user=u2, tok=u2_token)
|
||||
await self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token)
|
||||
await self.helper.join(room_id_two, user=u2, tok=u2_token)
|
||||
|
||||
# Check shared rooms again. We should now see both rooms.
|
||||
channel = self._get_mutual_rooms(u1_token, u2)
|
||||
channel = await self._get_mutual_rooms(u1_token, u2)
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
self.assertEqual(len(channel.json_body["joined"]), 2)
|
||||
self.assertEqual(channel.json_body["count"], 2)
|
||||
for room_id_id in channel.json_body["joined"]:
|
||||
self.assertIn(room_id_id, [room_id_one, room_id_two])
|
||||
|
||||
def _create_rooms_for_pagination_test(
|
||||
async def _create_rooms_for_pagination_test(
|
||||
self, count: int
|
||||
) -> tuple[str, str, list[str]]:
|
||||
u1 = self.register_user("user1", "pass")
|
||||
u1_token = self.login(u1, "pass")
|
||||
u2 = self.register_user("user2", "pass")
|
||||
u2_token = self.login(u2, "pass")
|
||||
u1 = await self.register_user("user1", "pass")
|
||||
u1_token = await self.login(u1, "pass")
|
||||
u2 = await self.register_user("user2", "pass")
|
||||
u2_token = await self.login(u2, "pass")
|
||||
room_ids = []
|
||||
for i in range(count):
|
||||
room_id = self.helper.create_room_as(u1, is_public=i % 2 == 0, tok=u1_token)
|
||||
self.helper.invite(room_id, src=u1, targ=u2, tok=u1_token)
|
||||
self.helper.join(room_id, user=u2, tok=u2_token)
|
||||
room_id = await self.helper.create_room_as(u1, is_public=i % 2 == 0, tok=u1_token)
|
||||
await self.helper.invite(room_id, src=u1, targ=u2, tok=u1_token)
|
||||
await self.helper.join(room_id, user=u2, tok=u2_token)
|
||||
room_ids.append(room_id)
|
||||
room_ids.sort()
|
||||
return u1_token, u2, room_ids
|
||||
|
||||
def test_shared_room_list_pagination_two_pages(self) -> None:
|
||||
u1_token, u2, room_ids = self._create_rooms_for_pagination_test(15)
|
||||
async def test_shared_room_list_pagination_two_pages(self) -> None:
|
||||
u1_token, u2, room_ids = await self._create_rooms_for_pagination_test(15)
|
||||
|
||||
channel = self._get_mutual_rooms(u1_token, u2)
|
||||
channel = await self._get_mutual_rooms(u1_token, u2)
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
self.assertEqual(channel.json_body["joined"], room_ids[0:10])
|
||||
self.assertEqual(channel.json_body["count"], 15)
|
||||
self.assertIn("next_batch", channel.json_body)
|
||||
|
||||
channel = self._get_mutual_rooms(u1_token, u2, channel.json_body["next_batch"])
|
||||
channel = await self._get_mutual_rooms(u1_token, u2, channel.json_body["next_batch"])
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
self.assertEqual(channel.json_body["joined"], room_ids[10:20])
|
||||
self.assertEqual(channel.json_body["count"], 15)
|
||||
self.assertNotIn("next_batch", channel.json_body)
|
||||
|
||||
def test_shared_room_list_pagination_one_page(self) -> None:
|
||||
u1_token, u2, room_ids = self._create_rooms_for_pagination_test(10)
|
||||
async def test_shared_room_list_pagination_one_page(self) -> None:
|
||||
u1_token, u2, room_ids = await self._create_rooms_for_pagination_test(10)
|
||||
|
||||
channel = self._get_mutual_rooms(u1_token, u2)
|
||||
channel = await self._get_mutual_rooms(u1_token, u2)
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
self.assertEqual(channel.json_body["joined"], room_ids)
|
||||
self.assertEqual(channel.json_body["count"], 10)
|
||||
self.assertNotIn("next_batch", channel.json_body)
|
||||
|
||||
def test_shared_room_list_pagination_invalid_token(self) -> None:
|
||||
u1_token, u2, room_ids = self._create_rooms_for_pagination_test(10)
|
||||
async def test_shared_room_list_pagination_invalid_token(self) -> None:
|
||||
u1_token, u2, room_ids = await self._create_rooms_for_pagination_test(10)
|
||||
|
||||
channel = self._get_mutual_rooms(u1_token, u2, "!<>##faketoken")
|
||||
channel = await self._get_mutual_rooms(u1_token, u2, "!<>##faketoken")
|
||||
self.assertEqual(400, channel.code, channel.result)
|
||||
self.assertEqual(
|
||||
"M_INVALID_PARAM", channel.json_body["errcode"], channel.result
|
||||
)
|
||||
|
||||
def test_shared_room_list_after_leave(self) -> None:
|
||||
async def test_shared_room_list_after_leave(self) -> None:
|
||||
"""
|
||||
A room should no longer be considered shared if the other
|
||||
user has left it.
|
||||
"""
|
||||
u1 = self.register_user("user1", "pass")
|
||||
u1_token = self.login(u1, "pass")
|
||||
u2 = self.register_user("user2", "pass")
|
||||
u2_token = self.login(u2, "pass")
|
||||
u1 = await self.register_user("user1", "pass")
|
||||
u1_token = await self.login(u1, "pass")
|
||||
u2 = await self.register_user("user2", "pass")
|
||||
u2_token = await self.login(u2, "pass")
|
||||
|
||||
room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
|
||||
self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
|
||||
self.helper.join(room, user=u2, tok=u2_token)
|
||||
room = await self.helper.create_room_as(u1, is_public=True, tok=u1_token)
|
||||
await self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
|
||||
await self.helper.join(room, user=u2, tok=u2_token)
|
||||
|
||||
# Assert user directory is not empty
|
||||
channel = self._get_mutual_rooms(u1_token, u2)
|
||||
channel = await self._get_mutual_rooms(u1_token, u2)
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
self.assertEqual(len(channel.json_body["joined"]), 1)
|
||||
self.assertEqual(channel.json_body["count"], 1)
|
||||
self.assertEqual(channel.json_body["joined"][0], room)
|
||||
|
||||
self.helper.leave(room, user=u1, tok=u1_token)
|
||||
await self.helper.leave(room, user=u1, tok=u1_token)
|
||||
|
||||
# Check user1's view of shared rooms with user2
|
||||
channel = self._get_mutual_rooms(u1_token, u2)
|
||||
channel = await self._get_mutual_rooms(u1_token, u2)
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
self.assertEqual(len(channel.json_body["joined"]), 0)
|
||||
self.assertEqual(channel.json_body["count"], 0)
|
||||
|
||||
# Check user2's view of shared rooms with user1
|
||||
channel = self._get_mutual_rooms(u2_token, u1)
|
||||
channel = await self._get_mutual_rooms(u2_token, u1)
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
self.assertEqual(len(channel.json_body["joined"]), 0)
|
||||
self.assertEqual(channel.json_body["count"], 0)
|
||||
|
||||
def test_shared_room_list_nonexistent_user(self) -> None:
|
||||
u1 = self.register_user("user1", "pass")
|
||||
u1_token = self.login(u1, "pass")
|
||||
async def test_shared_room_list_nonexistent_user(self) -> None:
|
||||
u1 = await self.register_user("user1", "pass")
|
||||
u1_token = await self.login(u1, "pass")
|
||||
|
||||
# Check shared rooms from user1's perspective.
|
||||
# We should see the one room in common
|
||||
channel = self._get_mutual_rooms(u1_token, "@meow:example.com")
|
||||
channel = await self._get_mutual_rooms(u1_token, "@meow:example.com")
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
self.assertEqual(len(channel.json_body["joined"]), 0)
|
||||
self.assertEqual(channel.json_body["count"], 0)
|
||||
self.assertNotIn("next_batch", channel.json_body)
|
||||
|
||||
def test_shared_room_list_invalid_user(self) -> None:
|
||||
u1 = self.register_user("user1", "pass")
|
||||
u1_token = self.login(u1, "pass")
|
||||
async def test_shared_room_list_invalid_user(self) -> None:
|
||||
u1 = await self.register_user("user1", "pass")
|
||||
u1_token = await self.login(u1, "pass")
|
||||
|
||||
channel = self._get_mutual_rooms(u1_token, "@:example.com")
|
||||
channel = await self._get_mutual_rooms(u1_token, "@:example.com")
|
||||
self.assertEqual(400, channel.code, channel.result)
|
||||
self.assertEqual(
|
||||
"M_INVALID_PARAM", channel.json_body["errcode"], channel.result
|
||||
)
|
||||
|
||||
channel = self._get_mutual_rooms(u1_token, "@" + "a" * 255 + ":example.com")
|
||||
channel = await self._get_mutual_rooms(u1_token, "@" + "a" * 255 + ":example.com")
|
||||
self.assertEqual(400, channel.code, channel.result)
|
||||
self.assertEqual(
|
||||
"M_INVALID_PARAM", channel.json_body["errcode"], channel.result
|
||||
)
|
||||
|
||||
channel = self._get_mutual_rooms(u1_token, "@🐈️:example.com")
|
||||
channel = await self._get_mutual_rooms(u1_token, "@🐈️:example.com")
|
||||
self.assertEqual(400, channel.code, channel.result)
|
||||
self.assertEqual(
|
||||
"M_INVALID_PARAM", channel.json_body["errcode"], channel.result
|
||||
|
||||
@@ -39,7 +39,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
notifications.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(
|
||||
async def prepare(
|
||||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||
) -> None:
|
||||
self.store = homeserver.get_datastores().main
|
||||
@@ -48,32 +48,32 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
self.sync_handler = homeserver.get_sync_handler()
|
||||
self.auth_handler = homeserver.get_auth_handler()
|
||||
|
||||
self.user_id = self.register_user("user", "pass")
|
||||
self.access_token = self.login("user", "pass")
|
||||
self.other_user_id = self.register_user("otheruser", "pass")
|
||||
self.other_access_token = self.login("otheruser", "pass")
|
||||
self.user_id = await self.register_user("user", "pass")
|
||||
self.access_token = await self.login("user", "pass")
|
||||
self.other_user_id = await self.register_user("otheruser", "pass")
|
||||
self.other_access_token = await self.login("otheruser", "pass")
|
||||
|
||||
# Create a room
|
||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
|
||||
self.room_id = await self.helper.create_room_as(self.user_id, tok=self.access_token)
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
# Mock out the calls over federation.
|
||||
fed_transport_client = Mock(spec=["send_transaction"])
|
||||
fed_transport_client.send_transaction = AsyncMock(return_value={})
|
||||
|
||||
return self.setup_test_homeserver(
|
||||
return await self.setup_test_homeserver(
|
||||
federation_transport_client=fed_transport_client,
|
||||
)
|
||||
|
||||
def test_notify_for_local_invites(self) -> None:
|
||||
async def test_notify_for_local_invites(self) -> None:
|
||||
"""
|
||||
Local users will get notified for invites
|
||||
"""
|
||||
# Check we start with no pushes
|
||||
self._request_notifications(from_token=None, limit=1, expected_count=0)
|
||||
await self._request_notifications(from_token=None, limit=1, expected_count=0)
|
||||
|
||||
# Send an invite
|
||||
self.helper.invite(
|
||||
await self.helper.invite(
|
||||
room=self.room_id,
|
||||
src=self.user_id,
|
||||
targ=self.other_user_id,
|
||||
@@ -81,7 +81,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
)
|
||||
|
||||
# We should have a notification now
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
"/notifications",
|
||||
access_token=self.other_access_token,
|
||||
@@ -94,26 +94,26 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
channel.json_body,
|
||||
)
|
||||
|
||||
def test_pagination_of_notifications(self) -> None:
|
||||
async def test_pagination_of_notifications(self) -> None:
|
||||
"""
|
||||
Check that pagination of notifications works.
|
||||
"""
|
||||
# Check we start with no pushes
|
||||
self._request_notifications(from_token=None, limit=1, expected_count=0)
|
||||
await self._request_notifications(from_token=None, limit=1, expected_count=0)
|
||||
|
||||
# Send an invite and have the other user join the room.
|
||||
self.helper.invite(
|
||||
await self.helper.invite(
|
||||
room=self.room_id,
|
||||
src=self.user_id,
|
||||
targ=self.other_user_id,
|
||||
tok=self.access_token,
|
||||
)
|
||||
self.helper.join(self.room_id, self.other_user_id, tok=self.other_access_token)
|
||||
await self.helper.join(self.room_id, self.other_user_id, tok=self.other_access_token)
|
||||
|
||||
# Send 5 messages in the room and note down their event IDs.
|
||||
sent_event_ids = []
|
||||
for _ in range(5):
|
||||
resp = self.helper.send_event(
|
||||
resp = await self.helper.send_event(
|
||||
self.room_id,
|
||||
"m.room.message",
|
||||
{"body": "honk", "msgtype": "m.text"},
|
||||
@@ -127,7 +127,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
sent_event_ids.reverse()
|
||||
|
||||
# We should have a few notifications now. Let's try and fetch the first 2.
|
||||
notification_event_ids, _ = self._request_notifications(
|
||||
notification_event_ids, _ = await self._request_notifications(
|
||||
from_token=None, limit=2, expected_count=2
|
||||
)
|
||||
|
||||
@@ -136,7 +136,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
|
||||
# Try requesting again without a 'from' query parameter. We should get the
|
||||
# same two notifications back.
|
||||
notification_event_ids, next_token = self._request_notifications(
|
||||
notification_event_ids, next_token = await self._request_notifications(
|
||||
from_token=None, limit=2, expected_count=2
|
||||
)
|
||||
self.assertEqual(notification_event_ids, sent_event_ids[:2])
|
||||
@@ -146,14 +146,14 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
#
|
||||
# We need to use the "next_token" from the response as the "from"
|
||||
# query parameter in the next request in order to paginate.
|
||||
notification_event_ids, next_token = self._request_notifications(
|
||||
notification_event_ids, next_token = await self._request_notifications(
|
||||
from_token=next_token, limit=5, expected_count=4
|
||||
)
|
||||
# Ensure we chop off the invite on the end.
|
||||
notification_event_ids = notification_event_ids[:-1]
|
||||
self.assertEqual(notification_event_ids, sent_event_ids[2:])
|
||||
|
||||
def _request_notifications(
|
||||
async def _request_notifications(
|
||||
self, from_token: str | None, limit: int, expected_count: int
|
||||
) -> tuple[list[str], str]:
|
||||
"""
|
||||
@@ -175,7 +175,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
if from_token is not None:
|
||||
path += f"&from={from_token}"
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
path,
|
||||
access_token=self.other_access_token,
|
||||
@@ -194,12 +194,12 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
|
||||
return event_ids, next_token
|
||||
|
||||
def test_parameters(self) -> None:
|
||||
async def test_parameters(self) -> None:
|
||||
"""
|
||||
Test that appropriate errors are returned when query parameters are malformed.
|
||||
"""
|
||||
# Test that no parameters are required.
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
"/notifications",
|
||||
access_token=self.other_access_token,
|
||||
@@ -207,7 +207,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
# Test that limit cannot be negative
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
"/notifications?limit=-1",
|
||||
access_token=self.other_access_token,
|
||||
@@ -215,7 +215,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 400)
|
||||
|
||||
# Test that the 'limit' parameter must be an integer.
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
"/notifications?limit=foobar",
|
||||
access_token=self.other_access_token,
|
||||
@@ -223,7 +223,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 400)
|
||||
|
||||
# Test that the 'from' parameter must be an integer.
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
"/notifications?from=osborne",
|
||||
access_token=self.other_access_token,
|
||||
|
||||
@@ -42,33 +42,33 @@ class PowerLevelsTestCase(HomeserverTestCase):
|
||||
sync.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
|
||||
return self.setup_test_homeserver(config=config)
|
||||
return await self.setup_test_homeserver(config=config)
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
# register a room admin, moderator and regular user
|
||||
self.admin_user_id = self.register_user("admin", "pass")
|
||||
self.admin_access_token = self.login("admin", "pass")
|
||||
self.mod_user_id = self.register_user("mod", "pass")
|
||||
self.mod_access_token = self.login("mod", "pass")
|
||||
self.user_user_id = self.register_user("user", "pass")
|
||||
self.user_access_token = self.login("user", "pass")
|
||||
self.admin_user_id = await self.register_user("admin", "pass")
|
||||
self.admin_access_token = await self.login("admin", "pass")
|
||||
self.mod_user_id = await self.register_user("mod", "pass")
|
||||
self.mod_access_token = await self.login("mod", "pass")
|
||||
self.user_user_id = await self.register_user("user", "pass")
|
||||
self.user_access_token = await self.login("user", "pass")
|
||||
|
||||
# Create a room
|
||||
self.room_id = self.helper.create_room_as(
|
||||
self.room_id = await self.helper.create_room_as(
|
||||
self.admin_user_id, tok=self.admin_access_token
|
||||
)
|
||||
|
||||
# Invite the other users
|
||||
self.helper.invite(
|
||||
await self.helper.invite(
|
||||
room=self.room_id,
|
||||
src=self.admin_user_id,
|
||||
tok=self.admin_access_token,
|
||||
targ=self.mod_user_id,
|
||||
)
|
||||
self.helper.invite(
|
||||
await self.helper.invite(
|
||||
room=self.room_id,
|
||||
src=self.admin_user_id,
|
||||
tok=self.admin_access_token,
|
||||
@@ -76,15 +76,15 @@ class PowerLevelsTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
# Make the other users join the room
|
||||
self.helper.join(
|
||||
await self.helper.join(
|
||||
room=self.room_id, user=self.mod_user_id, tok=self.mod_access_token
|
||||
)
|
||||
self.helper.join(
|
||||
await self.helper.join(
|
||||
room=self.room_id, user=self.user_user_id, tok=self.user_access_token
|
||||
)
|
||||
|
||||
# Mod the mod
|
||||
room_power_levels = self.helper.get_state(
|
||||
room_power_levels = await self.helper.get_state(
|
||||
self.room_id,
|
||||
"m.room.power_levels",
|
||||
tok=self.admin_access_token,
|
||||
@@ -93,16 +93,16 @@ class PowerLevelsTestCase(HomeserverTestCase):
|
||||
# Update existing power levels with mod at PL50
|
||||
room_power_levels["users"].update({self.mod_user_id: 50})
|
||||
|
||||
self.helper.send_state(
|
||||
await self.helper.send_state(
|
||||
self.room_id,
|
||||
"m.room.power_levels",
|
||||
room_power_levels,
|
||||
tok=self.admin_access_token,
|
||||
)
|
||||
|
||||
def test_non_admins_cannot_enable_room_encryption(self) -> None:
|
||||
async def test_non_admins_cannot_enable_room_encryption(self) -> None:
|
||||
# have the mod try to enable room encryption
|
||||
self.helper.send_state(
|
||||
await self.helper.send_state(
|
||||
self.room_id,
|
||||
"m.room.encryption",
|
||||
{"algorithm": "m.megolm.v1.aes-sha2"},
|
||||
@@ -111,7 +111,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
# have the user try to enable room encryption
|
||||
self.helper.send_state(
|
||||
await self.helper.send_state(
|
||||
self.room_id,
|
||||
"m.room.encryption",
|
||||
{"algorithm": "m.megolm.v1.aes-sha2"},
|
||||
@@ -119,9 +119,9 @@ class PowerLevelsTestCase(HomeserverTestCase):
|
||||
expect_code=HTTPStatus.FORBIDDEN, # expect failure
|
||||
)
|
||||
|
||||
def test_non_admins_cannot_send_server_acl(self) -> None:
|
||||
async def test_non_admins_cannot_send_server_acl(self) -> None:
|
||||
# have the mod try to send a server ACL
|
||||
self.helper.send_state(
|
||||
await self.helper.send_state(
|
||||
self.room_id,
|
||||
"m.room.server_acl",
|
||||
{
|
||||
@@ -134,7 +134,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
# have the user try to send a server ACL
|
||||
self.helper.send_state(
|
||||
await self.helper.send_state(
|
||||
self.room_id,
|
||||
"m.room.server_acl",
|
||||
{
|
||||
@@ -146,14 +146,14 @@ class PowerLevelsTestCase(HomeserverTestCase):
|
||||
expect_code=HTTPStatus.FORBIDDEN, # expect failure
|
||||
)
|
||||
|
||||
def test_non_admins_cannot_tombstone_room(self) -> None:
|
||||
async def test_non_admins_cannot_tombstone_room(self) -> None:
|
||||
# Create another room that will serve as our "upgraded room"
|
||||
self.upgraded_room_id = self.helper.create_room_as(
|
||||
self.upgraded_room_id = await self.helper.create_room_as(
|
||||
self.admin_user_id, tok=self.admin_access_token
|
||||
)
|
||||
|
||||
# have the mod try to send a tombstone event
|
||||
self.helper.send_state(
|
||||
await self.helper.send_state(
|
||||
self.room_id,
|
||||
"m.room.tombstone",
|
||||
{
|
||||
@@ -165,7 +165,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
# have the user try to send a tombstone event
|
||||
self.helper.send_state(
|
||||
await self.helper.send_state(
|
||||
self.room_id,
|
||||
"m.room.tombstone",
|
||||
{
|
||||
@@ -176,9 +176,9 @@ class PowerLevelsTestCase(HomeserverTestCase):
|
||||
expect_code=403, # expect failure
|
||||
)
|
||||
|
||||
def test_admins_can_enable_room_encryption(self) -> None:
|
||||
async def test_admins_can_enable_room_encryption(self) -> None:
|
||||
# have the admin try to enable room encryption
|
||||
self.helper.send_state(
|
||||
await self.helper.send_state(
|
||||
self.room_id,
|
||||
"m.room.encryption",
|
||||
{"algorithm": "m.megolm.v1.aes-sha2"},
|
||||
@@ -186,9 +186,9 @@ class PowerLevelsTestCase(HomeserverTestCase):
|
||||
expect_code=HTTPStatus.OK, # expect success
|
||||
)
|
||||
|
||||
def test_admins_can_send_server_acl(self) -> None:
|
||||
async def test_admins_can_send_server_acl(self) -> None:
|
||||
# have the admin try to send a server ACL
|
||||
self.helper.send_state(
|
||||
await self.helper.send_state(
|
||||
self.room_id,
|
||||
"m.room.server_acl",
|
||||
{
|
||||
@@ -200,14 +200,14 @@ class PowerLevelsTestCase(HomeserverTestCase):
|
||||
expect_code=HTTPStatus.OK, # expect success
|
||||
)
|
||||
|
||||
def test_admins_can_tombstone_room(self) -> None:
|
||||
async def test_admins_can_tombstone_room(self) -> None:
|
||||
# Create another room that will serve as our "upgraded room"
|
||||
self.upgraded_room_id = self.helper.create_room_as(
|
||||
self.upgraded_room_id = await self.helper.create_room_as(
|
||||
self.admin_user_id, tok=self.admin_access_token
|
||||
)
|
||||
|
||||
# have the admin try to send a tombstone event
|
||||
self.helper.send_state(
|
||||
await self.helper.send_state(
|
||||
self.room_id,
|
||||
"m.room.tombstone",
|
||||
{
|
||||
@@ -218,8 +218,8 @@ class PowerLevelsTestCase(HomeserverTestCase):
|
||||
expect_code=HTTPStatus.OK, # expect success
|
||||
)
|
||||
|
||||
def test_cannot_set_string_power_levels(self) -> None:
|
||||
room_power_levels = self.helper.get_state(
|
||||
async def test_cannot_set_string_power_levels(self) -> None:
|
||||
room_power_levels = await self.helper.get_state(
|
||||
self.room_id,
|
||||
"m.room.power_levels",
|
||||
tok=self.admin_access_token,
|
||||
@@ -228,7 +228,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
|
||||
# Update existing power levels with user at PL "0"
|
||||
room_power_levels["users"].update({self.user_user_id: "0"})
|
||||
|
||||
body = self.helper.send_state(
|
||||
body = await self.helper.send_state(
|
||||
self.room_id,
|
||||
"m.room.power_levels",
|
||||
room_power_levels,
|
||||
@@ -242,8 +242,8 @@ class PowerLevelsTestCase(HomeserverTestCase):
|
||||
body,
|
||||
)
|
||||
|
||||
def test_cannot_set_unsafe_large_power_levels(self) -> None:
|
||||
room_power_levels = self.helper.get_state(
|
||||
async def test_cannot_set_unsafe_large_power_levels(self) -> None:
|
||||
room_power_levels = await self.helper.get_state(
|
||||
self.room_id,
|
||||
"m.room.power_levels",
|
||||
tok=self.admin_access_token,
|
||||
@@ -254,7 +254,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
|
||||
{self.user_user_id: CANONICALJSON_MAX_INT + 1}
|
||||
)
|
||||
|
||||
body = self.helper.send_state(
|
||||
body = await self.helper.send_state(
|
||||
self.room_id,
|
||||
"m.room.power_levels",
|
||||
room_power_levels,
|
||||
@@ -268,8 +268,8 @@ class PowerLevelsTestCase(HomeserverTestCase):
|
||||
body,
|
||||
)
|
||||
|
||||
def test_cannot_set_unsafe_small_power_levels(self) -> None:
|
||||
room_power_levels = self.helper.get_state(
|
||||
async def test_cannot_set_unsafe_small_power_levels(self) -> None:
|
||||
room_power_levels = await self.helper.get_state(
|
||||
self.room_id,
|
||||
"m.room.power_levels",
|
||||
tok=self.admin_access_token,
|
||||
@@ -280,7 +280,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
|
||||
{self.user_user_id: CANONICALJSON_MIN_INT - 1}
|
||||
)
|
||||
|
||||
body = self.helper.send_state(
|
||||
body = await self.helper.send_state(
|
||||
self.room_id,
|
||||
"m.room.power_levels",
|
||||
room_power_levels,
|
||||
|
||||
@@ -40,11 +40,11 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
||||
user = UserID.from_string(user_id)
|
||||
servlets = [presence.register_servlets]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.presence_handler = Mock(spec=PresenceHandler)
|
||||
self.presence_handler.set_state = AsyncMock(return_value=None)
|
||||
|
||||
hs = self.setup_test_homeserver(
|
||||
hs = await self.setup_test_homeserver(
|
||||
"red",
|
||||
federation_client=Mock(),
|
||||
presence_handler=self.presence_handler,
|
||||
@@ -52,7 +52,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
return hs
|
||||
|
||||
def test_put_presence(self) -> None:
|
||||
async def test_put_presence(self) -> None:
|
||||
"""
|
||||
PUT to the status endpoint with use_presence enabled will call
|
||||
set_state on the presence handler.
|
||||
@@ -60,7 +60,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
||||
self.hs.config.server.presence_enabled = True
|
||||
|
||||
body = {"presence": "here", "status_msg": "beep boop"}
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT", "/presence/%s/status" % (self.user_id,), body
|
||||
)
|
||||
|
||||
@@ -68,14 +68,14 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(self.presence_handler.set_state.call_count, 1)
|
||||
|
||||
@unittest.override_config({"use_presence": False})
|
||||
def test_put_presence_disabled(self) -> None:
|
||||
async def test_put_presence_disabled(self) -> None:
|
||||
"""
|
||||
PUT to the status endpoint with presence disabled will NOT call
|
||||
set_state on the presence handler.
|
||||
"""
|
||||
|
||||
body = {"presence": "here", "status_msg": "beep boop"}
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT", "/presence/%s/status" % (self.user_id,), body
|
||||
)
|
||||
|
||||
@@ -83,14 +83,14 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(self.presence_handler.set_state.call_count, 0)
|
||||
|
||||
@unittest.override_config({"presence": {"enabled": "untracked"}})
|
||||
def test_put_presence_untracked(self) -> None:
|
||||
async def test_put_presence_untracked(self) -> None:
|
||||
"""
|
||||
PUT to the status endpoint with presence untracked will NOT call
|
||||
set_state on the presence handler.
|
||||
"""
|
||||
|
||||
body = {"presence": "here", "status_msg": "beep boop"}
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT", "/presence/%s/status" % (self.user_id,), body
|
||||
)
|
||||
|
||||
@@ -100,21 +100,21 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
||||
@override_config(
|
||||
{"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
|
||||
)
|
||||
def test_put_presence_over_ratelimit(self) -> None:
|
||||
async def test_put_presence_over_ratelimit(self) -> None:
|
||||
"""
|
||||
Multiple PUTs to the status endpoint without sufficient delay will be rate limited.
|
||||
"""
|
||||
self.hs.config.server.presence_enabled = True
|
||||
|
||||
body = {"presence": "here", "status_msg": "beep boop"}
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT", "/presence/%s/status" % (self.user_id,), body
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
|
||||
body = {"presence": "here", "status_msg": "beep boop"}
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT", "/presence/%s/status" % (self.user_id,), body
|
||||
)
|
||||
|
||||
@@ -124,14 +124,14 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
||||
@override_config(
|
||||
{"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
|
||||
)
|
||||
def test_put_presence_within_ratelimit(self) -> None:
|
||||
async def test_put_presence_within_ratelimit(self) -> None:
|
||||
"""
|
||||
Multiple PUTs to the status endpoint with sufficient delay should all call set_state.
|
||||
"""
|
||||
self.hs.config.server.presence_enabled = True
|
||||
|
||||
body = {"presence": "here", "status_msg": "beep boop"}
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT", "/presence/%s/status" % (self.user_id,), body
|
||||
)
|
||||
|
||||
@@ -141,7 +141,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
||||
self.reactor.advance(30)
|
||||
|
||||
body = {"presence": "here", "status_msg": "beep boop"}
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT", "/presence/%s/status" % (self.user_id,), body
|
||||
)
|
||||
|
||||
|
||||
+184
-165
@@ -50,27 +50,27 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.hs = self.setup_test_homeserver()
|
||||
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.hs = await self.setup_test_homeserver()
|
||||
return self.hs
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.owner = self.register_user("owner", "pass")
|
||||
self.owner_tok = self.login("owner", "pass")
|
||||
self.other = self.register_user("other", "pass", displayname="Bob")
|
||||
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.owner = await self.register_user("owner", "pass")
|
||||
self.owner_tok = await self.login("owner", "pass")
|
||||
self.other = await self.register_user("other", "pass", displayname="Bob")
|
||||
|
||||
def test_get_displayname(self) -> None:
|
||||
res = self._get_displayname()
|
||||
async def test_get_displayname(self) -> None:
|
||||
res = await self._get_displayname()
|
||||
self.assertEqual(res, "owner")
|
||||
|
||||
def test_get_displayname_rejects_bad_username(self) -> None:
|
||||
channel = self.make_request(
|
||||
async def test_get_displayname_rejects_bad_username(self) -> None:
|
||||
channel = await self.make_request(
|
||||
"GET", f"/profile/{urllib.parse.quote('@alice:')}/displayname"
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
|
||||
def test_set_displayname(self) -> None:
|
||||
channel = self.make_request(
|
||||
async def test_set_displayname(self) -> None:
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
"/profile/%s/displayname" % (self.owner,),
|
||||
content={"displayname": "test"},
|
||||
@@ -78,11 +78,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
res = self._get_displayname()
|
||||
res = await self._get_displayname()
|
||||
self.assertEqual(res, "test")
|
||||
|
||||
def test_set_displayname_with_extra_spaces(self) -> None:
|
||||
channel = self.make_request(
|
||||
async def test_set_displayname_with_extra_spaces(self) -> None:
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
"/profile/%s/displayname" % (self.owner,),
|
||||
content={"displayname": " test "},
|
||||
@@ -90,20 +90,20 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
res = self._get_displayname()
|
||||
res = await self._get_displayname()
|
||||
self.assertEqual(res, "test")
|
||||
|
||||
def test_set_displayname_noauth(self) -> None:
|
||||
channel = self.make_request(
|
||||
async def test_set_displayname_noauth(self) -> None:
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
"/profile/%s/displayname" % (self.owner,),
|
||||
content={"displayname": "test"},
|
||||
)
|
||||
self.assertEqual(channel.code, 401, channel.result)
|
||||
|
||||
def test_set_displayname_too_long(self) -> None:
|
||||
async def test_set_displayname_too_long(self) -> None:
|
||||
"""Attempts to set a stupid displayname should get a 400"""
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
"/profile/%s/displayname" % (self.owner,),
|
||||
content={"displayname": "test" * 100},
|
||||
@@ -111,15 +111,15 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
|
||||
res = self._get_displayname()
|
||||
res = await self._get_displayname()
|
||||
self.assertEqual(res, "owner")
|
||||
|
||||
def test_get_displayname_other(self) -> None:
|
||||
res = self._get_displayname(self.other)
|
||||
async def test_get_displayname_other(self) -> None:
|
||||
res = await self._get_displayname(self.other)
|
||||
self.assertEqual(res, "Bob")
|
||||
|
||||
def test_set_displayname_other(self) -> None:
|
||||
channel = self.make_request(
|
||||
async def test_set_displayname_other(self) -> None:
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
"/profile/%s/displayname" % (self.other,),
|
||||
content={"displayname": "test"},
|
||||
@@ -127,12 +127,12 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
|
||||
def test_get_avatar_url(self) -> None:
|
||||
res = self._get_avatar_url()
|
||||
async def test_get_avatar_url(self) -> None:
|
||||
res = await self._get_avatar_url()
|
||||
self.assertIsNone(res)
|
||||
|
||||
def test_set_avatar_url(self) -> None:
|
||||
channel = self.make_request(
|
||||
async def test_set_avatar_url(self) -> None:
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
"/profile/%s/avatar_url" % (self.owner,),
|
||||
content={"avatar_url": "http://my.server/pic.gif"},
|
||||
@@ -140,20 +140,20 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
res = self._get_avatar_url()
|
||||
res = await self._get_avatar_url()
|
||||
self.assertEqual(res, "http://my.server/pic.gif")
|
||||
|
||||
def test_set_avatar_url_noauth(self) -> None:
|
||||
channel = self.make_request(
|
||||
async def test_set_avatar_url_noauth(self) -> None:
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
"/profile/%s/avatar_url" % (self.owner,),
|
||||
content={"avatar_url": "http://my.server/pic.gif"},
|
||||
)
|
||||
self.assertEqual(channel.code, 401, channel.result)
|
||||
|
||||
def test_set_avatar_url_too_long(self) -> None:
|
||||
async def test_set_avatar_url_too_long(self) -> None:
|
||||
"""Attempts to set a stupid avatar_url should get a 400"""
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
"/profile/%s/avatar_url" % (self.owner,),
|
||||
content={"avatar_url": "http://my.server/pic.gif" * 100},
|
||||
@@ -161,15 +161,15 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
|
||||
res = self._get_avatar_url()
|
||||
res = await self._get_avatar_url()
|
||||
self.assertIsNone(res)
|
||||
|
||||
def test_get_avatar_url_other(self) -> None:
|
||||
res = self._get_avatar_url(self.other)
|
||||
async def test_get_avatar_url_other(self) -> None:
|
||||
res = await self._get_avatar_url(self.other)
|
||||
self.assertIsNone(res)
|
||||
|
||||
def test_set_avatar_url_other(self) -> None:
|
||||
channel = self.make_request(
|
||||
async def test_set_avatar_url_other(self) -> None:
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
"/profile/%s/avatar_url" % (self.other,),
|
||||
content={"avatar_url": "http://my.server/pic.gif"},
|
||||
@@ -177,8 +177,8 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
|
||||
def _get_displayname(self, name: str | None = None) -> str | None:
|
||||
channel = self.make_request(
|
||||
async def _get_displayname(self, name: str | None = None) -> str | None:
|
||||
channel = await self.make_request(
|
||||
"GET", "/profile/%s/displayname" % (name or self.owner,)
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
@@ -187,8 +187,8 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
# https://github.com/matrix-org/synapse/issues/13137.
|
||||
return channel.json_body.get("displayname")
|
||||
|
||||
def _get_avatar_url(self, name: str | None = None) -> str | None:
|
||||
channel = self.make_request(
|
||||
async def _get_avatar_url(self, name: str | None = None) -> str | None:
|
||||
channel = await self.make_request(
|
||||
"GET", "/profile/%s/avatar_url" % (name or self.owner,)
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
@@ -198,18 +198,18 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
return channel.json_body.get("avatar_url")
|
||||
|
||||
@unittest.override_config({"max_avatar_size": 50})
|
||||
def test_avatar_size_limit_global(self) -> None:
|
||||
async def test_avatar_size_limit_global(self) -> None:
|
||||
"""Tests that the maximum size limit for avatars is enforced when updating a
|
||||
global profile.
|
||||
"""
|
||||
self._setup_local_files(
|
||||
await self._setup_local_files(
|
||||
{
|
||||
"small": {"size": 40},
|
||||
"big": {"size": 60},
|
||||
}
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/profile/{self.owner}/avatar_url",
|
||||
content={"avatar_url": "mxc://test/big"},
|
||||
@@ -220,7 +220,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/profile/{self.owner}/avatar_url",
|
||||
content={"avatar_url": "mxc://test/small"},
|
||||
@@ -229,20 +229,20 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
@unittest.override_config({"max_avatar_size": 50})
|
||||
def test_avatar_size_limit_per_room(self) -> None:
|
||||
async def test_avatar_size_limit_per_room(self) -> None:
|
||||
"""Tests that the maximum size limit for avatars is enforced when updating a
|
||||
per-room profile.
|
||||
"""
|
||||
self._setup_local_files(
|
||||
await self._setup_local_files(
|
||||
{
|
||||
"small": {"size": 40},
|
||||
"big": {"size": 60},
|
||||
}
|
||||
)
|
||||
|
||||
room_id = self.helper.create_room_as(tok=self.owner_tok)
|
||||
room_id = await self.helper.create_room_as(tok=self.owner_tok)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
|
||||
content={"membership": "join", "avatar_url": "mxc://test/big"},
|
||||
@@ -253,7 +253,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
|
||||
content={"membership": "join", "avatar_url": "mxc://test/small"},
|
||||
@@ -262,18 +262,18 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
@unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
|
||||
def test_avatar_allowed_mime_type_global(self) -> None:
|
||||
async def test_avatar_allowed_mime_type_global(self) -> None:
|
||||
"""Tests that the MIME type whitelist for avatars is enforced when updating a
|
||||
global profile.
|
||||
"""
|
||||
self._setup_local_files(
|
||||
await self._setup_local_files(
|
||||
{
|
||||
"good": {"mimetype": "image/png"},
|
||||
"bad": {"mimetype": "application/octet-stream"},
|
||||
}
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/profile/{self.owner}/avatar_url",
|
||||
content={"avatar_url": "mxc://test/bad"},
|
||||
@@ -284,7 +284,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/profile/{self.owner}/avatar_url",
|
||||
content={"avatar_url": "mxc://test/good"},
|
||||
@@ -293,20 +293,20 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
@unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
|
||||
def test_avatar_allowed_mime_type_per_room(self) -> None:
|
||||
async def test_avatar_allowed_mime_type_per_room(self) -> None:
|
||||
"""Tests that the MIME type whitelist for avatars is enforced when updating a
|
||||
per-room profile.
|
||||
"""
|
||||
self._setup_local_files(
|
||||
await self._setup_local_files(
|
||||
{
|
||||
"good": {"mimetype": "image/png"},
|
||||
"bad": {"mimetype": "application/octet-stream"},
|
||||
}
|
||||
)
|
||||
|
||||
room_id = self.helper.create_room_as(tok=self.owner_tok)
|
||||
room_id = await self.helper.create_room_as(tok=self.owner_tok)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
|
||||
content={"membership": "join", "avatar_url": "mxc://test/bad"},
|
||||
@@ -317,7 +317,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
|
||||
content={"membership": "join", "avatar_url": "mxc://test/good"},
|
||||
@@ -328,12 +328,12 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
@unittest.override_config(
|
||||
{"experimental_features": {"msc4069_profile_inhibit_propagation": True}}
|
||||
)
|
||||
def test_msc4069_inhibit_propagation(self) -> None:
|
||||
async def test_msc4069_inhibit_propagation(self) -> None:
|
||||
"""Tests to ensure profile update propagation can be inhibited."""
|
||||
for prop in ["avatar_url", "displayname"]:
|
||||
room_id = self.helper.create_room_as(tok=self.owner_tok)
|
||||
room_id = await self.helper.create_room_as(tok=self.owner_tok)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
|
||||
content={"membership": "join", prop: "mxc://my.server/existing"},
|
||||
@@ -341,7 +341,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/profile/{self.owner}/{prop}?org.matrix.msc4069.propagate=false",
|
||||
content={prop: "http://my.server/pic.gif"},
|
||||
@@ -350,13 +350,13 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
res = (
|
||||
self._get_avatar_url()
|
||||
await self._get_avatar_url()
|
||||
if prop == "avatar_url"
|
||||
else self._get_displayname()
|
||||
else await self._get_displayname()
|
||||
)
|
||||
self.assertEqual(res, "http://my.server/pic.gif")
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
|
||||
access_token=self.owner_tok,
|
||||
@@ -364,14 +364,14 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual(channel.json_body.get(prop), "mxc://my.server/existing")
|
||||
|
||||
def test_msc4069_inhibit_propagation_disabled(self) -> None:
|
||||
async def test_msc4069_inhibit_propagation_disabled(self) -> None:
|
||||
"""Tests to ensure profile update propagation inhibit flags are ignored when the
|
||||
experimental flag is not enabled.
|
||||
"""
|
||||
for prop in ["avatar_url", "displayname"]:
|
||||
room_id = self.helper.create_room_as(tok=self.owner_tok)
|
||||
room_id = await self.helper.create_room_as(tok=self.owner_tok)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
|
||||
content={"membership": "join", prop: "mxc://my.server/existing"},
|
||||
@@ -379,7 +379,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/profile/{self.owner}/{prop}?org.matrix.msc4069.propagate=false",
|
||||
content={prop: "http://my.server/pic.gif"},
|
||||
@@ -387,14 +387,16 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
await self._wait_for_profile_propagation()
|
||||
|
||||
res = (
|
||||
self._get_avatar_url()
|
||||
await self._get_avatar_url()
|
||||
if prop == "avatar_url"
|
||||
else self._get_displayname()
|
||||
else await self._get_displayname()
|
||||
)
|
||||
self.assertEqual(res, "http://my.server/pic.gif")
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
|
||||
access_token=self.owner_tok,
|
||||
@@ -405,12 +407,25 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
# isn't enabled.
|
||||
self.assertEqual(channel.json_body.get(prop), "http://my.server/pic.gif")
|
||||
|
||||
def test_msc4069_inhibit_propagation_default(self) -> None:
|
||||
async def _wait_for_profile_propagation(self) -> None:
|
||||
"""Wait for the background task that propagates profile changes to rooms."""
|
||||
import asyncio
|
||||
from synapse.util.task_scheduler import TaskStatus
|
||||
for _ in range(50):
|
||||
tasks = await self.hs.get_task_scheduler().get_tasks(
|
||||
actions=["update_join_states"],
|
||||
statuses=[TaskStatus.ACTIVE, TaskStatus.SCHEDULED],
|
||||
)
|
||||
if not tasks:
|
||||
break
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
async def test_msc4069_inhibit_propagation_default(self) -> None:
|
||||
"""Tests to ensure profile update propagation happens by default."""
|
||||
for prop in ["avatar_url", "displayname"]:
|
||||
room_id = self.helper.create_room_as(tok=self.owner_tok)
|
||||
room_id = await self.helper.create_room_as(tok=self.owner_tok)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
|
||||
content={"membership": "join", prop: "mxc://my.server/existing"},
|
||||
@@ -418,7 +433,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/profile/{self.owner}/{prop}",
|
||||
content={prop: "http://my.server/pic.gif"},
|
||||
@@ -426,14 +441,16 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
await self._wait_for_profile_propagation()
|
||||
|
||||
res = (
|
||||
self._get_avatar_url()
|
||||
await self._get_avatar_url()
|
||||
if prop == "avatar_url"
|
||||
else self._get_displayname()
|
||||
else await self._get_displayname()
|
||||
)
|
||||
self.assertEqual(res, "http://my.server/pic.gif")
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
|
||||
access_token=self.owner_tok,
|
||||
@@ -447,12 +464,12 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
@unittest.override_config(
|
||||
{"experimental_features": {"msc4069_profile_inhibit_propagation": True}}
|
||||
)
|
||||
def test_msc4069_inhibit_propagation_like_default(self) -> None:
|
||||
async def test_msc4069_inhibit_propagation_like_default(self) -> None:
|
||||
"""Tests to ensure clients can request explicit profile propagation."""
|
||||
for prop in ["avatar_url", "displayname"]:
|
||||
room_id = self.helper.create_room_as(tok=self.owner_tok)
|
||||
room_id = await self.helper.create_room_as(tok=self.owner_tok)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
|
||||
content={"membership": "join", prop: "mxc://my.server/existing"},
|
||||
@@ -460,7 +477,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/profile/{self.owner}/{prop}?org.matrix.msc4069.propagate=true",
|
||||
content={prop: "http://my.server/pic.gif"},
|
||||
@@ -468,14 +485,16 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
await self._wait_for_profile_propagation()
|
||||
|
||||
res = (
|
||||
self._get_avatar_url()
|
||||
await self._get_avatar_url()
|
||||
if prop == "avatar_url"
|
||||
else self._get_displayname()
|
||||
else await self._get_displayname()
|
||||
)
|
||||
self.assertEqual(res, "http://my.server/pic.gif")
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
f"/rooms/{room_id}/state/m.room.member/{self.owner}",
|
||||
access_token=self.owner_tok,
|
||||
@@ -485,32 +504,32 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
# The client requested ?propagate=true, so it should have happened.
|
||||
self.assertEqual(channel.json_body.get(prop), "http://my.server/pic.gif")
|
||||
|
||||
def test_get_missing_custom_field(self) -> None:
|
||||
channel = self.make_request(
|
||||
async def test_get_missing_custom_field(self) -> None:
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/custom_field",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
|
||||
|
||||
def test_get_missing_custom_field_invalid_field_name(self) -> None:
|
||||
channel = self.make_request(
|
||||
async def test_get_missing_custom_field_invalid_field_name(self) -> None:
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/[custom_field]",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
def test_get_custom_field_rejects_bad_username(self) -> None:
|
||||
channel = self.make_request(
|
||||
async def test_get_custom_field_rejects_bad_username(self) -> None:
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v3/profile/{urllib.parse.quote('@alice:')}/custom_field",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
def test_set_custom_field(self) -> None:
|
||||
channel = self.make_request(
|
||||
async def test_set_custom_field(self) -> None:
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/custom_field",
|
||||
content={"custom_field": "test"},
|
||||
@@ -518,7 +537,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/custom_field",
|
||||
)
|
||||
@@ -526,7 +545,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.json_body, {"custom_field": "test"})
|
||||
|
||||
# Overwriting the field should work.
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/custom_field",
|
||||
content={"custom_field": "new_Value"},
|
||||
@@ -534,7 +553,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/custom_field",
|
||||
)
|
||||
@@ -542,7 +561,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.json_body, {"custom_field": "new_Value"})
|
||||
|
||||
# Deleting the field should work.
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"DELETE",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/custom_field",
|
||||
content={},
|
||||
@@ -550,14 +569,14 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/custom_field",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
|
||||
|
||||
def test_non_string(self) -> None:
|
||||
async def test_non_string(self) -> None:
|
||||
"""Non-string fields are supported for custom fields."""
|
||||
fields = {
|
||||
"bool_field": True,
|
||||
@@ -568,7 +587,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
}
|
||||
|
||||
for key, value in fields.items():
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/{key}",
|
||||
content={key: value},
|
||||
@@ -576,7 +595,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v3/profile/{self.owner}",
|
||||
)
|
||||
@@ -585,15 +604,15 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Check getting individual fields works.
|
||||
for key, value in fields.items():
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/{key}",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.json_body, {key: value})
|
||||
|
||||
def test_set_custom_field_noauth(self) -> None:
|
||||
channel = self.make_request(
|
||||
async def test_set_custom_field_noauth(self) -> None:
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/custom_field",
|
||||
content={"custom_field": "test"},
|
||||
@@ -601,12 +620,12 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 401, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.MISSING_TOKEN)
|
||||
|
||||
def test_set_custom_field_size(self) -> None:
|
||||
async def test_set_custom_field_size(self) -> None:
|
||||
"""
|
||||
Attempts to set a custom field name that is too long should get a 400 error.
|
||||
"""
|
||||
# Key is missing.
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/",
|
||||
content={"": "test"},
|
||||
@@ -617,7 +636,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Single key is too large.
|
||||
key = "c" * 500
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/{key}",
|
||||
content={key: "test"},
|
||||
@@ -626,7 +645,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.KEY_TOO_LARGE)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"DELETE",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/{key}",
|
||||
content={key: "test"},
|
||||
@@ -636,7 +655,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.KEY_TOO_LARGE)
|
||||
|
||||
# Key doesn't match body.
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/custom_field",
|
||||
content={"diff_key": "test"},
|
||||
@@ -645,7 +664,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
|
||||
|
||||
def test_set_custom_field_profile_too_long(self) -> None:
|
||||
async def test_set_custom_field_profile_too_long(self) -> None:
|
||||
"""
|
||||
Attempts to set a custom field that would push the overall profile too large.
|
||||
"""
|
||||
@@ -665,7 +684,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
# 2 braces, 1 comma
|
||||
# 3 + 21 + 65498 = 65522 < 65536.
|
||||
key = "a"
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/{key}",
|
||||
content={key: "a" * 65498},
|
||||
@@ -674,7 +693,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
# Get the entire profile.
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v3/profile/{self.owner}",
|
||||
access_token=self.owner_tok,
|
||||
@@ -693,7 +712,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# The next one should fail, note the value has a (JSON) length of 2.
|
||||
key = "b"
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/{key}",
|
||||
content={key: "1" + "a" * ADDITIONAL_CHARS},
|
||||
@@ -703,7 +722,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
|
||||
|
||||
# Setting an avatar or (longer) display name should not work.
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/profile/{self.owner}/displayname",
|
||||
content={"displayname": "owner12345678" + "a" * ADDITIONAL_CHARS},
|
||||
@@ -712,7 +731,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/profile/{self.owner}/avatar_url",
|
||||
content={"avatar_url": "mxc://foo/bar"},
|
||||
@@ -723,7 +742,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Removing a single byte should work.
|
||||
key = "b"
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/{key}",
|
||||
content={key: "" + "a" * ADDITIONAL_CHARS},
|
||||
@@ -733,7 +752,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Finally, setting a field that already exists to a value that is <= in length should work.
|
||||
key = "a"
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/{key}",
|
||||
content={key: ""},
|
||||
@@ -743,8 +762,8 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
finally:
|
||||
sql_logger.disabled = sql_logger_was_disabled
|
||||
|
||||
def test_set_custom_field_displayname(self) -> None:
|
||||
channel = self.make_request(
|
||||
async def test_set_custom_field_displayname(self) -> None:
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/displayname",
|
||||
content={"displayname": "test"},
|
||||
@@ -752,11 +771,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
displayname = self._get_displayname()
|
||||
displayname = await self._get_displayname()
|
||||
self.assertEqual(displayname, "test")
|
||||
|
||||
def test_set_custom_field_avatar_url(self) -> None:
|
||||
channel = self.make_request(
|
||||
async def test_set_custom_field_avatar_url(self) -> None:
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/v3/profile/{self.owner}/avatar_url",
|
||||
content={"avatar_url": "mxc://test/good"},
|
||||
@@ -764,12 +783,12 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
avatar_url = self._get_avatar_url()
|
||||
avatar_url = await self._get_avatar_url()
|
||||
self.assertEqual(avatar_url, "mxc://test/good")
|
||||
|
||||
def test_set_custom_field_other(self) -> None:
|
||||
async def test_set_custom_field_other(self) -> None:
|
||||
"""Setting someone else's profile field should fail"""
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/v3/profile/{self.other}/custom_field",
|
||||
content={"custom_field": "test"},
|
||||
@@ -778,7 +797,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 403, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
|
||||
|
||||
def _setup_local_files(self, names_and_props: dict[str, dict[str, Any]]) -> None:
|
||||
async def _setup_local_files(self, names_and_props: dict[str, dict[str, Any]]) -> None:
|
||||
"""Stores metadata about files in the database.
|
||||
|
||||
Args:
|
||||
@@ -790,7 +809,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
store = self.hs.get_datastores().main
|
||||
|
||||
for name, props in names_and_props.items():
|
||||
self.get_success(
|
||||
await self.get_success(
|
||||
store.store_local_media(
|
||||
media_id=name,
|
||||
media_type=props.get("mimetype", "image/png"),
|
||||
@@ -810,68 +829,68 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
config["require_auth_for_profile_requests"] = True
|
||||
config["limit_profile_requests_to_users_who_share_rooms"] = True
|
||||
self.hs = self.setup_test_homeserver(config=config)
|
||||
self.hs = await self.setup_test_homeserver(config=config)
|
||||
|
||||
return self.hs
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
# User owning the requested profile.
|
||||
self.owner = self.register_user("owner", "pass")
|
||||
self.owner_tok = self.login("owner", "pass")
|
||||
self.owner = await self.register_user("owner", "pass")
|
||||
self.owner_tok = await self.login("owner", "pass")
|
||||
self.profile_url = "/profile/%s" % (self.owner)
|
||||
|
||||
# User requesting the profile.
|
||||
self.requester = self.register_user("requester", "pass")
|
||||
self.requester_tok = self.login("requester", "pass")
|
||||
self.requester = await self.register_user("requester", "pass")
|
||||
self.requester_tok = await self.login("requester", "pass")
|
||||
|
||||
self.room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok)
|
||||
self.room_id = await self.helper.create_room_as(self.owner, tok=self.owner_tok)
|
||||
|
||||
def test_no_auth(self) -> None:
|
||||
self.try_fetch_profile(401)
|
||||
async def test_no_auth(self) -> None:
|
||||
await self.try_fetch_profile(401)
|
||||
|
||||
def test_not_in_shared_room(self) -> None:
|
||||
self.ensure_requester_left_room()
|
||||
async def test_not_in_shared_room(self) -> None:
|
||||
await self.ensure_requester_left_room()
|
||||
|
||||
self.try_fetch_profile(403, access_token=self.requester_tok)
|
||||
await self.try_fetch_profile(403, access_token=self.requester_tok)
|
||||
|
||||
def test_in_shared_room(self) -> None:
|
||||
self.ensure_requester_left_room()
|
||||
async def test_in_shared_room(self) -> None:
|
||||
await self.ensure_requester_left_room()
|
||||
|
||||
self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok)
|
||||
await self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok)
|
||||
|
||||
self.try_fetch_profile(200, self.requester_tok)
|
||||
await self.try_fetch_profile(200, self.requester_tok)
|
||||
|
||||
def try_fetch_profile(
|
||||
async def try_fetch_profile(
|
||||
self, expected_code: int, access_token: str | None = None
|
||||
) -> None:
|
||||
self.request_profile(expected_code, access_token=access_token)
|
||||
await self.request_profile(expected_code, access_token=access_token)
|
||||
|
||||
self.request_profile(
|
||||
await self.request_profile(
|
||||
expected_code, url_suffix="/displayname", access_token=access_token
|
||||
)
|
||||
|
||||
self.request_profile(
|
||||
await self.request_profile(
|
||||
expected_code, url_suffix="/avatar_url", access_token=access_token
|
||||
)
|
||||
|
||||
def request_profile(
|
||||
async def request_profile(
|
||||
self,
|
||||
expected_code: int,
|
||||
url_suffix: str = "",
|
||||
access_token: str | None = None,
|
||||
) -> None:
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET", self.profile_url + url_suffix, access_token=access_token
|
||||
)
|
||||
self.assertEqual(channel.code, expected_code, channel.result)
|
||||
|
||||
def ensure_requester_left_room(self) -> None:
|
||||
async def ensure_requester_left_room(self) -> None:
|
||||
try:
|
||||
self.helper.leave(
|
||||
await self.helper.leave(
|
||||
room=self.room_id, user=self.requester, tok=self.requester_tok
|
||||
)
|
||||
except AssertionError:
|
||||
@@ -888,36 +907,36 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
|
||||
profile.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
config["require_auth_for_profile_requests"] = True
|
||||
config["limit_profile_requests_to_users_who_share_rooms"] = True
|
||||
self.hs = self.setup_test_homeserver(config=config)
|
||||
self.hs = await self.setup_test_homeserver(config=config)
|
||||
|
||||
return self.hs
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
# User requesting the profile.
|
||||
self.requester = self.register_user("requester", "pass")
|
||||
self.requester_tok = self.login("requester", "pass")
|
||||
self.requester = await self.register_user("requester", "pass")
|
||||
self.requester_tok = await self.login("requester", "pass")
|
||||
|
||||
def test_can_lookup_own_profile(self) -> None:
|
||||
async def test_can_lookup_own_profile(self) -> None:
|
||||
"""Tests that a user can lookup their own profile without having to be in a room
|
||||
if 'require_auth_for_profile_requests' is set to true in the server's config.
|
||||
"""
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET", "/profile/" + self.requester, access_token=self.requester_tok
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
"/profile/" + self.requester + "/displayname",
|
||||
access_token=self.requester_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"GET",
|
||||
"/profile/" + self.requester + "/avatar_url",
|
||||
access_token=self.requester_tok,
|
||||
|
||||
@@ -43,7 +43,7 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase):
|
||||
admin.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
|
||||
# merge this default retention config with anything that was specified in
|
||||
@@ -56,27 +56,27 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase):
|
||||
retention_config.update(config.get("retention", {}))
|
||||
config["retention"] = retention_config
|
||||
|
||||
self.hs = self.setup_test_homeserver(config=config)
|
||||
self.hs = await self.setup_test_homeserver(config=config)
|
||||
|
||||
return self.hs
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.owner = self.register_user("owner", "pass")
|
||||
self.owner_tok = self.login("owner", "pass")
|
||||
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.owner = await self.register_user("owner", "pass")
|
||||
self.owner_tok = await self.login("owner", "pass")
|
||||
self.store = self.hs.get_datastores().main
|
||||
self.clock = self.hs.get_clock()
|
||||
|
||||
def test_send_read_marker(self) -> None:
|
||||
room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok)
|
||||
async def test_send_read_marker(self) -> None:
|
||||
room_id = await self.helper.create_room_as(self.owner, tok=self.owner_tok)
|
||||
|
||||
def send_message() -> str:
|
||||
res = self.helper.send(room_id=room_id, body="1", tok=self.owner_tok)
|
||||
async def send_message() -> str:
|
||||
res = await self.helper.send(room_id=room_id, body="1", tok=self.owner_tok)
|
||||
return res["event_id"]
|
||||
|
||||
# Test setting the read marker on the room
|
||||
event_id_1 = send_message()
|
||||
event_id_1 = await send_message()
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"POST",
|
||||
f"/rooms/{room_id}/read_markers",
|
||||
content={
|
||||
@@ -87,8 +87,8 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
# Test moving the read marker to a newer event
|
||||
event_id_2 = send_message()
|
||||
channel = self.make_request(
|
||||
event_id_2 = await send_message()
|
||||
channel = await self.make_request(
|
||||
"POST",
|
||||
f"/rooms/{room_id}/read_markers",
|
||||
content={
|
||||
@@ -98,30 +98,30 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
def test_send_read_marker_missing_previous_event(self) -> None:
|
||||
async def test_send_read_marker_missing_previous_event(self) -> None:
|
||||
"""
|
||||
Test moving a read marker from an event that previously existed but was
|
||||
later removed due to retention rules.
|
||||
"""
|
||||
|
||||
room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok)
|
||||
room_id = await self.helper.create_room_as(self.owner, tok=self.owner_tok)
|
||||
|
||||
# Set retention rule on the room so we remove old events to test this case
|
||||
self.helper.send_state(
|
||||
await self.helper.send_state(
|
||||
room_id=room_id,
|
||||
event_type=EventTypes.Retention,
|
||||
body={"max_lifetime": ONE_DAY_MS},
|
||||
tok=self.owner_tok,
|
||||
)
|
||||
|
||||
def send_message() -> str:
|
||||
res = self.helper.send(room_id=room_id, body="1", tok=self.owner_tok)
|
||||
async def send_message() -> str:
|
||||
res = await self.helper.send(room_id=room_id, body="1", tok=self.owner_tok)
|
||||
return res["event_id"]
|
||||
|
||||
# Test setting the read marker on the room
|
||||
event_id_1 = send_message()
|
||||
event_id_1 = await send_message()
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"POST",
|
||||
f"/rooms/{room_id}/read_markers",
|
||||
content={
|
||||
@@ -131,16 +131,16 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
# Send a second message (retention will not remove the latest event ever)
|
||||
send_message()
|
||||
await send_message()
|
||||
# And then advance so retention rules remove the first event (where the marker is)
|
||||
self.reactor.advance(ONE_DAY_MS * 2 / 1000)
|
||||
await self.pump(ONE_DAY_MS * 2 / 1000)
|
||||
|
||||
event = self.get_success(self.store.get_event(event_id_1, allow_none=True))
|
||||
event = await self.get_success(self.store.get_event(event_id_1, allow_none=True))
|
||||
assert event is None
|
||||
|
||||
# Test moving the read marker to a newer event
|
||||
event_id_2 = send_message()
|
||||
channel = self.make_request(
|
||||
event_id_2 = await send_message()
|
||||
channel = await self.make_request(
|
||||
"POST",
|
||||
f"/rooms/{room_id}/read_markers",
|
||||
content={
|
||||
|
||||
@@ -42,18 +42,18 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
|
||||
user = UserID.from_string(user_id)
|
||||
servlets = [room.register_servlets]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
hs = self.setup_test_homeserver("red")
|
||||
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
hs = await self.setup_test_homeserver("red")
|
||||
self.event_source = hs.get_event_sources().sources.typing
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.room_id = self.helper.create_room_as(self.user_id)
|
||||
async def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.room_id = await self.helper.create_room_as(self.user_id)
|
||||
# Need another user to make notifications actually work
|
||||
self.helper.join(self.room_id, user="@jim:red")
|
||||
await self.helper.join(self.room_id, user="@jim:red")
|
||||
|
||||
def test_set_typing(self) -> None:
|
||||
channel = self.make_request(
|
||||
async def test_set_typing(self) -> None:
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
|
||||
b'{"typing": true, "timeout": 30000}',
|
||||
@@ -61,7 +61,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(200, channel.code)
|
||||
|
||||
self.assertEqual(self.event_source.get_current_key(), 1)
|
||||
events = self.get_success(
|
||||
events = await self.get_success(
|
||||
self.event_source.get_new_events(
|
||||
user=UserID.from_string(self.user_id),
|
||||
from_key=0,
|
||||
@@ -82,16 +82,16 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
|
||||
],
|
||||
)
|
||||
|
||||
def test_set_not_typing(self) -> None:
|
||||
channel = self.make_request(
|
||||
async def test_set_not_typing(self) -> None:
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
|
||||
b'{"typing": false}',
|
||||
)
|
||||
self.assertEqual(200, channel.code)
|
||||
|
||||
def test_typing_timeout(self) -> None:
|
||||
channel = self.make_request(
|
||||
async def test_typing_timeout(self) -> None:
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
|
||||
b'{"typing": true, "timeout": 30000}',
|
||||
@@ -100,11 +100,11 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEqual(self.event_source.get_current_key(), 1)
|
||||
|
||||
self.reactor.advance(36)
|
||||
await self.pump(36)
|
||||
|
||||
self.assertEqual(self.event_source.get_current_key(), 2)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"PUT",
|
||||
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
|
||||
b'{"typing": true, "timeout": 30000}',
|
||||
|
||||
+16
-11
@@ -202,13 +202,11 @@ class TestCase(_stdlib_unittest.IsolatedAsyncioTestCase):
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
# if we're not starting in the sentinel logcontext, then to be honest
|
||||
# all future bets are off.
|
||||
# If we're not starting in the sentinel logcontext, reset it.
|
||||
# In asyncio, cancelled background tasks may leave a non-sentinel
|
||||
# context during event loop shutdown between tests.
|
||||
if current_context():
|
||||
self.fail(
|
||||
"Test starting with non-sentinel logging context %s"
|
||||
% (current_context(),)
|
||||
)
|
||||
set_current_context(SENTINEL_CONTEXT)
|
||||
|
||||
# Disable GC for duration of test (re-enabled in tearDown).
|
||||
gc.disable()
|
||||
@@ -585,9 +583,11 @@ class HomeserverTestCase(TestCase):
|
||||
"""
|
||||
if hasattr(self, 'hs') and self.hs is not None:
|
||||
self.hs.get_clock().shutdown()
|
||||
# Give event loop time for cancellation to propagate
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
# Give event loop time for cancellation callbacks to propagate.
|
||||
# Multiple yields are needed because cancellation of a task may
|
||||
# trigger cleanup in other tasks (e.g. context managers exiting).
|
||||
for _ in range(5):
|
||||
await asyncio.sleep(0)
|
||||
# Ensure we're back in the sentinel context
|
||||
set_current_context(SENTINEL_CONTEXT)
|
||||
|
||||
@@ -819,11 +819,16 @@ class HomeserverTestCase(TestCase):
|
||||
|
||||
``reactor.advance()`` delegates to ``clock.advance()``, so calling
|
||||
either one advances the same fake-time source.
|
||||
|
||||
After advancing, yields multiple times to allow background tasks
|
||||
(including those running in executor threads) to complete.
|
||||
"""
|
||||
# Advance fake time (fires pending sleeps and callFromThread)
|
||||
self.reactor.advance(by)
|
||||
# Yield to the event loop so callbacks can run
|
||||
await asyncio.sleep(0)
|
||||
# Yield to the event loop multiple times so background tasks
|
||||
# (including DB operations in executor threads) can complete.
|
||||
for _ in range(20):
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
async def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV:
|
||||
"""Await an awaitable, optionally advancing fake time first."""
|
||||
|
||||
Reference in New Issue
Block a user