diff --git a/tests/module_api/test_spamchecker.py b/tests/module_api/test_spamchecker.py index 42ef969ce0..7b2f778073 100644 --- a/tests/module_api/test_spamchecker.py +++ b/tests/module_api/test_spamchecker.py @@ -16,18 +16,27 @@ from typing import Literal from twisted.internet.testing import MemoryReactor -from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, +) from synapse.config.server import DEFAULT_ROOM_VERSION +from synapse.events import make_event_from_dict +from synapse.module_api import EventBase from synapse.rest import admin, login, room, room_upgrade_rest_servlet from synapse.server import HomeServer from synapse.types import Codes, JsonDict from synapse.util.clock import Clock +from tests import unittest from tests.server import FakeChannel from tests.unittest import HomeserverTestCase class SpamCheckerTestCase(HomeserverTestCase): + """Tests for the spam checker module API.""" + servlets = [ room.register_servlets, admin.register_servlets, @@ -284,3 +293,173 @@ class SpamCheckerTestCase(HomeserverTestCase): self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) + + +class FederatedEventSpamCheckMetadataTestCase(unittest.FederatingHomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) + self._module_api = hs.get_module_api() + self._store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() + self._federation_event_handler = hs.get_federation_event_handler() + self._federation_server = hs.get_federation_server() + self._state_handler = hs.get_state_handler() + self._persistence_controller = hs.get_storage_controllers().persistence + + # Create a room + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + self.room_id = self.helper.create_room_as( + user1_id, tok=user1_tok, is_public=True + ) + + # Prepare a join for the 'remote' user + state_map = self.get_success( + self._storage_controllers.state.get_current_state(self.room_id) + ) + forward_extremity_event_ids = self.get_success( + self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) + ) + self.remote_user_id = f"@remoteuser:{self.OTHER_SERVER_NAME}" + self.remote_user_join_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "room_id": self.room_id, + "sender": self.remote_user_id, + "state_key": self.remote_user_id, + "depth": 1000, + "origin_server_ts": 1, + "type": EventTypes.Member, + "content": {"membership": Membership.JOIN}, + "auth_events": [ + state_map[(EventTypes.Create, "")].event_id, + state_map[(EventTypes.JoinRules, "")].event_id, + ], + "prev_events": list(forward_extremity_event_ids), + } + ), + room_version=self.hs.config.server.default_room_version, + ) + + # Send the join + self.get_success( + self._federation_event_handler.on_receive_pdu( + self.OTHER_SERVER_NAME, self.remote_user_join_event + ) + ) + + # Check the join made it to the 'local' view of the room + self.assertEqual( + self.get_success(self._store.get_latest_event_ids_in_room(self.room_id)), + {self.remote_user_join_event.event_id}, + ) + + def test_federated_events_with_spam_checker_metadata(self) -> None: + """ + Simulates receiving spammy and non-spammy events over federation, + then checks their `spam_checker_spammy` flag is set properly. + """ + + async def check_event_for_spam(event: EventBase) -> Literal["NOT_SPAM"] | Codes: + if event.type == EventTypes.Message: + if "ham" not in event.content["body"]: + return Codes.FORBIDDEN + return "NOT_SPAM" + + # Register a spam checker callback that only allows messages with 'ham' + self._module_api.register_spam_checker_callbacks( + check_event_for_spam=check_event_for_spam + ) + + # Prepare a spammy and a non-spammy event. + forward_extremity_event_ids = self.get_success( + self._store.get_latest_event_ids_in_room(self.room_id) + ) + state_map = self.get_success( + self._storage_controllers.state.get_current_state(self.room_id) + ) + spammy_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "room_id": self.room_id, + "sender": self.remote_user_id, + "depth": 2000, + "origin_server_ts": 2, + "type": EventTypes.Message, + "content": {"body": "this is spam", "msgtype": "m.text"}, + "auth_events": [ + state_map[(EventTypes.Create, "")].event_id, + state_map[(EventTypes.JoinRules, "")].event_id, + state_map[(EventTypes.Member, self.remote_user_id)].event_id, + ], + "prev_events": list(forward_extremity_event_ids), + } + ), + room_version=self.hs.config.server.default_room_version, + ) + non_spammy_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "room_id": self.room_id, + "sender": self.remote_user_id, + "depth": 2000, + "origin_server_ts": 2, + "type": EventTypes.Message, + "content": {"body": "delicious ham", "msgtype": "m.text"}, + "auth_events": [ + state_map[(EventTypes.Create, "")].event_id, + state_map[(EventTypes.JoinRules, "")].event_id, + state_map[(EventTypes.Member, self.remote_user_id)].event_id, + ], + "prev_events": list(forward_extremity_event_ids), + } + ), + room_version=self.hs.config.server.default_room_version, + ) + + # Receive these events over federation + # We need to let the federation server have them because it will + # invoke `_check_sigs_and_hash` which invokes the spam checker. + self.get_success( + self._federation_server._handle_received_pdu( + self.OTHER_SERVER_NAME, spammy_event + ) + ) + self.get_success( + self._federation_server._handle_received_pdu( + self.OTHER_SERVER_NAME, non_spammy_event + ) + ) + + # Retrieve the events from the database + retrieved_spammy_event = self.get_success( + self._store.get_event(spammy_event.event_id, allow_rejected=True) + ) + retrieved_non_spammy_event = self.get_success( + self._store.get_event(non_spammy_event.event_id, allow_rejected=True) + ) + + # Assert the spammy flags (and soft-failed flags, for good measure) are set properly + self.assertTrue( + retrieved_spammy_event.internal_metadata.spam_checker_spammy, + "Spammy inbound event should be marked as spam_checker_spammy!", + ) + self.assertTrue( + retrieved_spammy_event.internal_metadata.is_soft_failed(), + "Spammy inbound event should be soft-failed.", + ) + + self.assertFalse( + retrieved_non_spammy_event.internal_metadata.spam_checker_spammy, + "Non-spammy inbound event should not be marked as spam_checker_spammy!", + ) + self.assertFalse( + retrieved_non_spammy_event.internal_metadata.is_soft_failed(), + "Non-spammy inbound event should not be soft-failed.", + )