mirror of
https://github.com/element-hq/synapse.git
synced 2026-05-25 22:54:07 +00:00
Add a test for the spam_checker_spammy flag
This commit is contained in:
@@ -16,18 +16,27 @@ from typing import Literal
|
||||
|
||||
from twisted.internet.testing import MemoryReactor
|
||||
|
||||
from synapse.api.constants import EventContentFields, EventTypes
|
||||
from synapse.api.constants import (
|
||||
EventContentFields,
|
||||
EventTypes,
|
||||
Membership,
|
||||
)
|
||||
from synapse.config.server import DEFAULT_ROOM_VERSION
|
||||
from synapse.events import make_event_from_dict
|
||||
from synapse.module_api import EventBase
|
||||
from synapse.rest import admin, login, room, room_upgrade_rest_servlet
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import Codes, JsonDict
|
||||
from synapse.util.clock import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeChannel
|
||||
from tests.unittest import HomeserverTestCase
|
||||
|
||||
|
||||
class SpamCheckerTestCase(HomeserverTestCase):
|
||||
"""Tests for the spam checker module API."""
|
||||
|
||||
servlets = [
|
||||
room.register_servlets,
|
||||
admin.register_servlets,
|
||||
@@ -284,3 +293,173 @@ class SpamCheckerTestCase(HomeserverTestCase):
|
||||
|
||||
self.assertEqual(channel.code, 403)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
|
||||
|
||||
|
||||
class FederatedEventSpamCheckMetadataTestCase(unittest.FederatingHomeserverTestCase):
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
super().prepare(reactor, clock, hs)
|
||||
self._module_api = hs.get_module_api()
|
||||
self._store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._federation_event_handler = hs.get_federation_event_handler()
|
||||
self._federation_server = hs.get_federation_server()
|
||||
self._state_handler = hs.get_state_handler()
|
||||
self._persistence_controller = hs.get_storage_controllers().persistence
|
||||
|
||||
# Create a room
|
||||
user1_id = self.register_user("user1", "pass")
|
||||
user1_tok = self.login(user1_id, "pass")
|
||||
self.room_id = self.helper.create_room_as(
|
||||
user1_id, tok=user1_tok, is_public=True
|
||||
)
|
||||
|
||||
# Prepare a join for the 'remote' user
|
||||
state_map = self.get_success(
|
||||
self._storage_controllers.state.get_current_state(self.room_id)
|
||||
)
|
||||
forward_extremity_event_ids = self.get_success(
|
||||
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
|
||||
)
|
||||
self.remote_user_id = f"@remoteuser:{self.OTHER_SERVER_NAME}"
|
||||
self.remote_user_join_event = make_event_from_dict(
|
||||
self.add_hashes_and_signatures_from_other_server(
|
||||
{
|
||||
"room_id": self.room_id,
|
||||
"sender": self.remote_user_id,
|
||||
"state_key": self.remote_user_id,
|
||||
"depth": 1000,
|
||||
"origin_server_ts": 1,
|
||||
"type": EventTypes.Member,
|
||||
"content": {"membership": Membership.JOIN},
|
||||
"auth_events": [
|
||||
state_map[(EventTypes.Create, "")].event_id,
|
||||
state_map[(EventTypes.JoinRules, "")].event_id,
|
||||
],
|
||||
"prev_events": list(forward_extremity_event_ids),
|
||||
}
|
||||
),
|
||||
room_version=self.hs.config.server.default_room_version,
|
||||
)
|
||||
|
||||
# Send the join
|
||||
self.get_success(
|
||||
self._federation_event_handler.on_receive_pdu(
|
||||
self.OTHER_SERVER_NAME, self.remote_user_join_event
|
||||
)
|
||||
)
|
||||
|
||||
# Check the join made it to the 'local' view of the room
|
||||
self.assertEqual(
|
||||
self.get_success(self._store.get_latest_event_ids_in_room(self.room_id)),
|
||||
{self.remote_user_join_event.event_id},
|
||||
)
|
||||
|
||||
def test_federated_events_with_spam_checker_metadata(self) -> None:
|
||||
"""
|
||||
Simulates receiving spammy and non-spammy events over federation,
|
||||
then checks their `spam_checker_spammy` flag is set properly.
|
||||
"""
|
||||
|
||||
async def check_event_for_spam(event: EventBase) -> Literal["NOT_SPAM"] | Codes:
|
||||
if event.type == EventTypes.Message:
|
||||
if "ham" not in event.content["body"]:
|
||||
return Codes.FORBIDDEN
|
||||
return "NOT_SPAM"
|
||||
|
||||
# Register a spam checker callback that only allows messages with 'ham'
|
||||
self._module_api.register_spam_checker_callbacks(
|
||||
check_event_for_spam=check_event_for_spam
|
||||
)
|
||||
|
||||
# Prepare a spammy and a non-spammy event.
|
||||
forward_extremity_event_ids = self.get_success(
|
||||
self._store.get_latest_event_ids_in_room(self.room_id)
|
||||
)
|
||||
state_map = self.get_success(
|
||||
self._storage_controllers.state.get_current_state(self.room_id)
|
||||
)
|
||||
spammy_event = make_event_from_dict(
|
||||
self.add_hashes_and_signatures_from_other_server(
|
||||
{
|
||||
"room_id": self.room_id,
|
||||
"sender": self.remote_user_id,
|
||||
"depth": 2000,
|
||||
"origin_server_ts": 2,
|
||||
"type": EventTypes.Message,
|
||||
"content": {"body": "this is spam", "msgtype": "m.text"},
|
||||
"auth_events": [
|
||||
state_map[(EventTypes.Create, "")].event_id,
|
||||
state_map[(EventTypes.JoinRules, "")].event_id,
|
||||
state_map[(EventTypes.Member, self.remote_user_id)].event_id,
|
||||
],
|
||||
"prev_events": list(forward_extremity_event_ids),
|
||||
}
|
||||
),
|
||||
room_version=self.hs.config.server.default_room_version,
|
||||
)
|
||||
non_spammy_event = make_event_from_dict(
|
||||
self.add_hashes_and_signatures_from_other_server(
|
||||
{
|
||||
"room_id": self.room_id,
|
||||
"sender": self.remote_user_id,
|
||||
"depth": 2000,
|
||||
"origin_server_ts": 2,
|
||||
"type": EventTypes.Message,
|
||||
"content": {"body": "delicious ham", "msgtype": "m.text"},
|
||||
"auth_events": [
|
||||
state_map[(EventTypes.Create, "")].event_id,
|
||||
state_map[(EventTypes.JoinRules, "")].event_id,
|
||||
state_map[(EventTypes.Member, self.remote_user_id)].event_id,
|
||||
],
|
||||
"prev_events": list(forward_extremity_event_ids),
|
||||
}
|
||||
),
|
||||
room_version=self.hs.config.server.default_room_version,
|
||||
)
|
||||
|
||||
# Receive these events over federation
|
||||
# We need to let the federation server have them because it will
|
||||
# invoke `_check_sigs_and_hash` which invokes the spam checker.
|
||||
self.get_success(
|
||||
self._federation_server._handle_received_pdu(
|
||||
self.OTHER_SERVER_NAME, spammy_event
|
||||
)
|
||||
)
|
||||
self.get_success(
|
||||
self._federation_server._handle_received_pdu(
|
||||
self.OTHER_SERVER_NAME, non_spammy_event
|
||||
)
|
||||
)
|
||||
|
||||
# Retrieve the events from the database
|
||||
retrieved_spammy_event = self.get_success(
|
||||
self._store.get_event(spammy_event.event_id, allow_rejected=True)
|
||||
)
|
||||
retrieved_non_spammy_event = self.get_success(
|
||||
self._store.get_event(non_spammy_event.event_id, allow_rejected=True)
|
||||
)
|
||||
|
||||
# Assert the spammy flags (and soft-failed flags, for good measure) are set properly
|
||||
self.assertTrue(
|
||||
retrieved_spammy_event.internal_metadata.spam_checker_spammy,
|
||||
"Spammy inbound event should be marked as spam_checker_spammy!",
|
||||
)
|
||||
self.assertTrue(
|
||||
retrieved_spammy_event.internal_metadata.is_soft_failed(),
|
||||
"Spammy inbound event should be soft-failed.",
|
||||
)
|
||||
|
||||
self.assertFalse(
|
||||
retrieved_non_spammy_event.internal_metadata.spam_checker_spammy,
|
||||
"Non-spammy inbound event should not be marked as spam_checker_spammy!",
|
||||
)
|
||||
self.assertFalse(
|
||||
retrieved_non_spammy_event.internal_metadata.is_soft_failed(),
|
||||
"Non-spammy inbound event should not be soft-failed.",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user