Add a test for the spam_checker_spammy flag

This commit is contained in:
Olivier 'reivilibre
2026-02-23 14:38:38 +00:00
parent be62d4cdad
commit 96059cbbfe
+180 -1
View File
@@ -16,18 +16,27 @@ from typing import Literal
from twisted.internet.testing import MemoryReactor
from synapse.api.constants import EventContentFields, EventTypes
from synapse.api.constants import (
EventContentFields,
EventTypes,
Membership,
)
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.events import make_event_from_dict
from synapse.module_api import EventBase
from synapse.rest import admin, login, room, room_upgrade_rest_servlet
from synapse.server import HomeServer
from synapse.types import Codes, JsonDict
from synapse.util.clock import Clock
from tests import unittest
from tests.server import FakeChannel
from tests.unittest import HomeserverTestCase
class SpamCheckerTestCase(HomeserverTestCase):
"""Tests for the spam checker module API."""
servlets = [
room.register_servlets,
admin.register_servlets,
@@ -284,3 +293,173 @@ class SpamCheckerTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
class FederatedEventSpamCheckMetadataTestCase(unittest.FederatingHomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
self._module_api = hs.get_module_api()
self._store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._federation_event_handler = hs.get_federation_event_handler()
self._federation_server = hs.get_federation_server()
self._state_handler = hs.get_state_handler()
self._persistence_controller = hs.get_storage_controllers().persistence
# Create a room
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
self.room_id = self.helper.create_room_as(
user1_id, tok=user1_tok, is_public=True
)
# Prepare a join for the 'remote' user
state_map = self.get_success(
self._storage_controllers.state.get_current_state(self.room_id)
)
forward_extremity_event_ids = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)
self.remote_user_id = f"@remoteuser:{self.OTHER_SERVER_NAME}"
self.remote_user_join_event = make_event_from_dict(
self.add_hashes_and_signatures_from_other_server(
{
"room_id": self.room_id,
"sender": self.remote_user_id,
"state_key": self.remote_user_id,
"depth": 1000,
"origin_server_ts": 1,
"type": EventTypes.Member,
"content": {"membership": Membership.JOIN},
"auth_events": [
state_map[(EventTypes.Create, "")].event_id,
state_map[(EventTypes.JoinRules, "")].event_id,
],
"prev_events": list(forward_extremity_event_ids),
}
),
room_version=self.hs.config.server.default_room_version,
)
# Send the join
self.get_success(
self._federation_event_handler.on_receive_pdu(
self.OTHER_SERVER_NAME, self.remote_user_join_event
)
)
# Check the join made it to the 'local' view of the room
self.assertEqual(
self.get_success(self._store.get_latest_event_ids_in_room(self.room_id)),
{self.remote_user_join_event.event_id},
)
def test_federated_events_with_spam_checker_metadata(self) -> None:
"""
Simulates receiving spammy and non-spammy events over federation,
then checks their `spam_checker_spammy` flag is set properly.
"""
async def check_event_for_spam(event: EventBase) -> Literal["NOT_SPAM"] | Codes:
if event.type == EventTypes.Message:
if "ham" not in event.content["body"]:
return Codes.FORBIDDEN
return "NOT_SPAM"
# Register a spam checker callback that only allows messages with 'ham'
self._module_api.register_spam_checker_callbacks(
check_event_for_spam=check_event_for_spam
)
# Prepare a spammy and a non-spammy event.
forward_extremity_event_ids = self.get_success(
self._store.get_latest_event_ids_in_room(self.room_id)
)
state_map = self.get_success(
self._storage_controllers.state.get_current_state(self.room_id)
)
spammy_event = make_event_from_dict(
self.add_hashes_and_signatures_from_other_server(
{
"room_id": self.room_id,
"sender": self.remote_user_id,
"depth": 2000,
"origin_server_ts": 2,
"type": EventTypes.Message,
"content": {"body": "this is spam", "msgtype": "m.text"},
"auth_events": [
state_map[(EventTypes.Create, "")].event_id,
state_map[(EventTypes.JoinRules, "")].event_id,
state_map[(EventTypes.Member, self.remote_user_id)].event_id,
],
"prev_events": list(forward_extremity_event_ids),
}
),
room_version=self.hs.config.server.default_room_version,
)
non_spammy_event = make_event_from_dict(
self.add_hashes_and_signatures_from_other_server(
{
"room_id": self.room_id,
"sender": self.remote_user_id,
"depth": 2000,
"origin_server_ts": 2,
"type": EventTypes.Message,
"content": {"body": "delicious ham", "msgtype": "m.text"},
"auth_events": [
state_map[(EventTypes.Create, "")].event_id,
state_map[(EventTypes.JoinRules, "")].event_id,
state_map[(EventTypes.Member, self.remote_user_id)].event_id,
],
"prev_events": list(forward_extremity_event_ids),
}
),
room_version=self.hs.config.server.default_room_version,
)
# Receive these events over federation
# We need to let the federation server have them because it will
# invoke `_check_sigs_and_hash` which invokes the spam checker.
self.get_success(
self._federation_server._handle_received_pdu(
self.OTHER_SERVER_NAME, spammy_event
)
)
self.get_success(
self._federation_server._handle_received_pdu(
self.OTHER_SERVER_NAME, non_spammy_event
)
)
# Retrieve the events from the database
retrieved_spammy_event = self.get_success(
self._store.get_event(spammy_event.event_id, allow_rejected=True)
)
retrieved_non_spammy_event = self.get_success(
self._store.get_event(non_spammy_event.event_id, allow_rejected=True)
)
# Assert the spammy flags (and soft-failed flags, for good measure) are set properly
self.assertTrue(
retrieved_spammy_event.internal_metadata.spam_checker_spammy,
"Spammy inbound event should be marked as spam_checker_spammy!",
)
self.assertTrue(
retrieved_spammy_event.internal_metadata.is_soft_failed(),
"Spammy inbound event should be soft-failed.",
)
self.assertFalse(
retrieved_non_spammy_event.internal_metadata.spam_checker_spammy,
"Non-spammy inbound event should not be marked as spam_checker_spammy!",
)
self.assertFalse(
retrieved_non_spammy_event.internal_metadata.is_soft_failed(),
"Non-spammy inbound event should not be soft-failed.",
)