diff --git a/changelog.d/19424.feature b/changelog.d/19424.feature new file mode 100644 index 0000000000..8f241a87b5 --- /dev/null +++ b/changelog.d/19424.feature @@ -0,0 +1 @@ +Add experimental support for [MSC4242](https://github.com/matrix-org/matrix-spec-proposals/pull/4242): State DAGs. Excludes federation support. \ No newline at end of file diff --git a/changelog.d/19473.misc b/changelog.d/19473.misc new file mode 100644 index 0000000000..596d8a6b26 --- /dev/null +++ b/changelog.d/19473.misc @@ -0,0 +1 @@ +Reduce database disk space usage by pruning old rows from `device_lists_changes_in_room`. diff --git a/changelog.d/19639.bugfix b/changelog.d/19639.bugfix new file mode 100644 index 0000000000..2d2928c1ee --- /dev/null +++ b/changelog.d/19639.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.145 where a non-admin could bypass admin checks for downloading remote quarantined media. This relied on the media already being previously present on the homeserver. \ No newline at end of file diff --git a/changelog.d/19657.feature b/changelog.d/19657.feature new file mode 100644 index 0000000000..f87ef9fed8 --- /dev/null +++ b/changelog.d/19657.feature @@ -0,0 +1,2 @@ +Adds [Admin API](https://element-hq.github.io/synapse/latest/usage/administration/admin_api/index.html) endpoints to +list, fetch and delete user reports. \ No newline at end of file diff --git a/changelog.d/19696.doc b/changelog.d/19696.doc new file mode 100644 index 0000000000..c531359444 --- /dev/null +++ b/changelog.d/19696.doc @@ -0,0 +1 @@ +Update the developer stream docs for creating a new stream to highlight places that require documentation updates. diff --git a/changelog.d/19709.misc b/changelog.d/19709.misc new file mode 100644 index 0000000000..596d8a6b26 --- /dev/null +++ b/changelog.d/19709.misc @@ -0,0 +1 @@ +Reduce database disk space usage by pruning old rows from `device_lists_changes_in_room`. 
diff --git a/changelog.d/19712.misc b/changelog.d/19712.misc new file mode 100644 index 0000000000..c8fa79bf47 --- /dev/null +++ b/changelog.d/19712.misc @@ -0,0 +1 @@ +Small simplifications to the events class. diff --git a/contrib/grafana/synapse.json b/contrib/grafana/synapse.json index ceacc10369..a67d1aba0b 100644 --- a/contrib/grafana/synapse.json +++ b/contrib/grafana/synapse.json @@ -6809,6 +6809,155 @@ ], "title": "Stale extremity dropping", "type": "timeseries" + }, + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "description": "For a given percentage P, the number X where P% of events were persisted to rooms with X state DAG forward extremities or fewer.", + "fieldConfig": { + "defaults": { + "links": [] + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 50 + }, + "id": 181, + "options": { + "alertThreshold": true + }, + "pluginVersion": "9.2.2", + "targets": [ + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "expr": "histogram_quantile(0.5, rate(synapse_storage_msc4242_state_dag_forward_extremities_persisted_bucket{server_name=\"$server_name\"}[$bucket_size]) and on (instance, job, index) (synapse_storage_events_persisted_events_total > 0))", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "50%", + "refId": "A" + }, + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "expr": "histogram_quantile(0.75, rate(synapse_storage_msc4242_state_dag_forward_extremities_persisted_bucket{server_name=\"$server_name\"}[$bucket_size]) and on (instance, job, index) (synapse_storage_events_persisted_events_total > 0))", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "75%", + "refId": "B" + }, + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "expr": "histogram_quantile(0.90, rate(synapse_storage_msc4242_state_dag_forward_extremities_persisted_bucket{server_name=\"$server_name\"}[$bucket_size]) and on (instance, job, index) (synapse_storage_events_persisted_events_total > 
0))", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "90%", + "refId": "C" + }, + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "expr": "histogram_quantile(0.99, rate(synapse_storage_msc4242_state_dag_forward_extremities_persisted_bucket{server_name=\"$server_name\"}[$bucket_size]) and on (instance, job, index) (synapse_storage_events_persisted_events_total > 0))", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "99%", + "refId": "D" + } + ], + "title": "Events persisted, by number of state DAG forward extremities in room (quantiles)", + "type": "timeseries" + }, + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "description": "Colour reflects the number of events persisted to rooms with the given number of state DAG forward extremities, or fewer.", + "fieldConfig": { + "defaults": { + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "scaleDistribution": { + "type": "linear" + } + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 50 + }, + "id": 127, + "options": { + "calculate": false, + "calculation": {}, + "cellGap": 1, + "cellValues": {}, + "color": { + "exponent": 0.5, + "fill": "#5794F2", + "min": 0, + "mode": "opacity", + "reverse": false, + "scale": "exponential", + "scheme": "Oranges", + "steps": 128 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" + }, + "filterValues": { + "le": 1e-9 + }, + "legend": { + "show": true + }, + "rowsFrame": { + "layout": "auto" + }, + "showValue": "never", + "tooltip": { + "show": true, + "yHistogram": true + }, + "yAxis": { + "axisPlacement": "left", + "decimals": 0, + "reverse": false, + "unit": "short" + } + }, + "pluginVersion": "9.2.2", + "targets": [ + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "expr": "rate(synapse_storage_msc4242_state_dag_forward_extremities_persisted_bucket{server_name=\"$server_name\"}[$bucket_size]) and on (instance, job, index) 
(synapse_storage_events_persisted_events_total > 0)", + "format": "heatmap", + "intervalFactor": 1, + "legendFormat": "{{le}}", + "refId": "A" + } + ], + "title": "Events persisted, by number of state DAG forward extremities in room (heatmap)", + "type": "heatmap" } ], "title": "Extremities", @@ -7711,4 +7860,4 @@ "uid": "000000012", "version": 1, "weekStart": "" -} +} \ No newline at end of file diff --git a/docs/development/synapse_architecture/streams.md b/docs/development/synapse_architecture/streams.md index b4057199a9..e7ab79091e 100644 --- a/docs/development/synapse_architecture/streams.md +++ b/docs/development/synapse_architecture/streams.md @@ -179,6 +179,13 @@ necessary registration and event handling. - don't forget the super call - add local-only [invalidations to your writer transactions](https://github.com/element-hq/synapse/blob/4367fb2d078c52959aeca0fe6874539c53e8360d/synapse/storage/databases/main/thread_subscriptions.py#L201) +**Update docs:** +- Update the [*Stream + writers*](https://github.com/element-hq/synapse/blob/develop/docs/workers.md#stream-writers) + section in the worker docs with a new section for the stream +- If this stream can only be handled by specific workers, add a new section to the + [upgrade notes](https://github.com/element-hq/synapse/blob/develop/docs/upgrade.md). + **For streams to be used in sync:** - add a new field to [`StreamToken`](https://github.com/element-hq/synapse/blob/4367fb2d078c52959aeca0fe6874539c53e8360d/synapse/types/__init__.py#L1003) - add a new [`StreamKeyType`](https://github.com/element-hq/synapse/blob/4367fb2d078c52959aeca0fe6874539c53e8360d/synapse/types/__init__.py#L999) diff --git a/poetry.lock b/poetry.lock index ef5c13684d..fd8f1a43c9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -26,15 +26,15 @@ files = [ [[package]] name = "authlib" -version = "1.6.9" +version = "1.6.11" description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients." 
optional = true python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"oidc\" or extra == \"jwt\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"jwt\" or extra == \"oidc\"" files = [ - {file = "authlib-1.6.9-py2.py3-none-any.whl", hash = "sha256:f08b4c14e08f0861dc18a32357b33fbcfd2ea86cfe3fe149484b4d764c4a0ac3"}, - {file = "authlib-1.6.9.tar.gz", hash = "sha256:d8f2421e7e5980cc1ddb4e32d3f5fa659cfaf60d8eaf3281ebed192e4ab74f04"}, + {file = "authlib-1.6.11-py2.py3-none-any.whl", hash = "sha256:c8687a9a26451c51a34a06fa17bb97cb15bba46a6a626755e2d7f50da8bff3e3"}, + {file = "authlib-1.6.11.tar.gz", hash = "sha256:64db35b9b01aeccb4715a6c9a6613a06f2bd7be2ab9d2eb89edd1dfc7580a38f"}, ] [package.dependencies] @@ -62,7 +62,7 @@ description = "Backport of CPython tarfile module" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "python_version < \"3.12\" and platform_machine != \"ppc64le\" and platform_machine != \"s390x\"" +markers = "platform_machine != \"ppc64le\" and platform_machine != \"s390x\" and python_version < \"3.12\"" files = [ {file = "backports.tarfile-1.2.0-py3-none-any.whl", hash = "sha256:77e284d754527b01fb1e6fa8a1afe577858ebe4e9dad8919e34c862cb399bc34"}, {file = "backports_tarfile-1.2.0.tar.gz", hash = "sha256:d75e02c268746e1b8144c278978b6e98e85de6ad16f8e4b0844a154557eca991"}, @@ -531,7 +531,7 @@ description = "XML bomb protection for Python stdlib modules" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" groups = ["main"] -markers = "extra == \"saml2\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"saml2\"" files = [ {file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"}, {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, @@ -556,7 +556,7 @@ description = "XPath 1.0/2.0/3.0/3.1 parsers and selectors for ElementTree 
and l optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"saml2\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"saml2\"" files = [ {file = "elementpath-4.8.0-py3-none-any.whl", hash = "sha256:5393191f84969bcf8033b05ec4593ef940e58622ea13cefe60ecefbbf09d58d9"}, {file = "elementpath-4.8.0.tar.gz", hash = "sha256:5822a2560d99e2633d95f78694c7ff9646adaa187db520da200a8e9479dc46ae"}, @@ -606,7 +606,7 @@ description = "Python wrapper for hiredis" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"redis\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"redis\"" files = [ {file = "hiredis-3.3.1-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:f525734382a47f9828c9d6a1501522c78d5935466d8e2be1a41ba40ca5bb922b"}, {file = "hiredis-3.3.1-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:6e2e1024f0a021777740cb7c633a0efb2c4a4bc570f508223a8dcbcf79f99ef9"}, @@ -889,7 +889,7 @@ description = "Read metadata from Python packages" optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "python_version < \"3.12\" and platform_machine != \"ppc64le\" and platform_machine != \"s390x\"" +markers = "platform_machine != \"ppc64le\" and platform_machine != \"s390x\" and python_version < \"3.12\"" files = [ {file = "importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151"}, {file = "importlib_metadata-8.7.1.tar.gz", hash = "sha256:49fef1ae6440c182052f407c8d34a68f72efc36db9ca90dc0113398f2fdde8bb"}, @@ -930,7 +930,7 @@ description = "Jaeger Python OpenTracing Tracer implementation" optional = true python-versions = ">=3.7" groups = ["main"] -markers = "extra == \"opentracing\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"opentracing\"" files = [ {file = "jaeger-client-4.8.0.tar.gz", hash = "sha256:3157836edab8e2c209bd2d6ae61113db36f7ee399e66b1dcbb715d87ab49bfe0"}, ] @@ -1122,7 +1122,7 @@ 
description = "A strictly RFC 4510 conforming LDAP V3 pure Python client library optional = true python-versions = "*" groups = ["main"] -markers = "extra == \"matrix-synapse-ldap3\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"matrix-synapse-ldap3\"" files = [ {file = "ldap3-2.9.1-py2.py3-none-any.whl", hash = "sha256:5869596fc4948797020d3f03b7939da938778a0f9e2009f7a072ccf92b8e8d70"}, {file = "ldap3-2.9.1.tar.gz", hash = "sha256:f3e7fc4718e3f09dda568b57100095e0ce58633bcabbed8667ce3f8fbaa4229f"}, @@ -1239,7 +1239,7 @@ description = "Powerful and Pythonic XML processing library combining libxml2/li optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"url-preview\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"url-preview\"" files = [ {file = "lxml-6.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e77dd455b9a16bbd2a5036a63ddbd479c19572af81b624e79ef422f929eef388"}, {file = "lxml-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d444858b9f07cefff6455b983aea9a67f7462ba1f6cbe4a21e8bf6791bf2153"}, @@ -1553,7 +1553,7 @@ description = "An LDAP3 auth provider for Synapse" optional = true python-versions = ">=3.10" groups = ["main"] -markers = "extra == \"matrix-synapse-ldap3\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"matrix-synapse-ldap3\"" files = [ {file = "matrix_synapse_ldap3-0.4.0-py3-none-any.whl", hash = "sha256:bf080037230d2af5fd3639cb87266de65c1cad7a68ea206278c5b4bf9c1a17f3"}, {file = "matrix_synapse_ldap3-0.4.0.tar.gz", hash = "sha256:cff52ba780170de5e6e8af42863d2648ee23f3bf0a9fea6db52372f9fc00be2b"}, @@ -1834,7 +1834,7 @@ description = "OpenTracing API for Python. 
See documentation at http://opentraci optional = true python-versions = "*" groups = ["main"] -markers = "extra == \"opentracing\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"opentracing\"" files = [ {file = "opentracing-2.4.0.tar.gz", hash = "sha256:a173117e6ef580d55874734d1fa7ecb6f3655160b8b8974a2a1e98e5ec9c840d"}, ] @@ -2032,7 +2032,7 @@ description = "psycopg2 - Python-PostgreSQL Database Adapter" optional = true python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"postgres\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"postgres\"" files = [ {file = "psycopg2-2.9.11-cp310-cp310-win_amd64.whl", hash = "sha256:103e857f46bb76908768ead4e2d0ba1d1a130e7b8ed77d3ae91e8b33481813e8"}, {file = "psycopg2-2.9.11-cp311-cp311-win_amd64.whl", hash = "sha256:210daed32e18f35e3140a1ebe059ac29209dd96468f2f7559aa59f75ee82a5cb"}, @@ -2050,7 +2050,7 @@ description = ".. image:: https://travis-ci.org/chtd/psycopg2cffi.svg?branch=mas optional = true python-versions = "*" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\" and (extra == \"postgres\" or extra == \"all\")" +markers = "platform_python_implementation == \"PyPy\" and (extra == \"all\" or extra == \"postgres\")" files = [ {file = "psycopg2cffi-2.9.0.tar.gz", hash = "sha256:7e272edcd837de3a1d12b62185eb85c45a19feda9e62fa1b120c54f9e8d35c52"}, ] @@ -2066,7 +2066,7 @@ description = "A Simple library to enable psycopg2 compatability" optional = true python-versions = "*" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\" and (extra == \"postgres\" or extra == \"all\")" +markers = "platform_python_implementation == \"PyPy\" and (extra == \"all\" or extra == \"postgres\")" files = [ {file = "psycopg2cffi-compat-1.1.tar.gz", hash = "sha256:d25e921748475522b33d13420aad5c2831c743227dc1f1f2585e0fdb5c914e05"}, ] @@ -2348,7 +2348,7 @@ description = "A development tool to measure, monitor and analyze the memory beh optional = true python-versions 
= ">=3.6" groups = ["main"] -markers = "extra == \"cache-memory\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"cache-memory\"" files = [ {file = "Pympler-1.0.1-py3-none-any.whl", hash = "sha256:d260dda9ae781e1eab6ea15bacb84015849833ba5555f141d2d9b7b7473b307d"}, {file = "Pympler-1.0.1.tar.gz", hash = "sha256:993f1a3599ca3f4fcd7160c7545ad06310c9e12f70174ae7ae8d4e25f6c5d3fa"}, @@ -2480,7 +2480,7 @@ description = "Python implementation of SAML Version 2 Standard" optional = true python-versions = ">=3.9,<4.0" groups = ["main"] -markers = "extra == \"saml2\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"saml2\"" files = [ {file = "pysaml2-7.5.0-py3-none-any.whl", hash = "sha256:bc6627cc344476a83c757f440a73fda1369f13b6fda1b4e16bca63ffbabb5318"}, {file = "pysaml2-7.5.0.tar.gz", hash = "sha256:f36871d4e5ee857c6b85532e942550d2cf90ea4ee943d75eb681044bbc4f54f7"}, @@ -2505,7 +2505,7 @@ description = "Extensions to the standard Python datetime module" optional = true python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" groups = ["main"] -markers = "extra == \"saml2\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"saml2\"" files = [ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, @@ -2533,7 +2533,7 @@ description = "World timezone definitions, modern and historical" optional = true python-versions = "*" groups = ["main"] -markers = "extra == \"saml2\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"saml2\"" files = [ {file = "pytz-2026.1.post1-py2.py3-none-any.whl", hash = "sha256:f2fd16142fda348286a75e1a524be810bb05d444e5a081f37f7affc635035f7a"}, {file = "pytz-2026.1.post1.tar.gz", hash = "sha256:3378dde6a0c3d26719182142c56e60c7f9af7e968076f31aae569d72a0358ee1"}, @@ -2938,7 +2938,7 @@ 
description = "Python client for Sentry (https://sentry.io)" optional = true python-versions = ">=3.6" groups = ["main"] -markers = "extra == \"sentry\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"sentry\"" files = [ {file = "sentry_sdk-2.57.0-py2.py3-none-any.whl", hash = "sha256:812c8bf5ff3d2f0e89c82f5ce80ab3a6423e102729c4706af7413fd1eb480585"}, {file = "sentry_sdk-2.57.0.tar.gz", hash = "sha256:4be8d1e71c32fb27f79c577a337ac8912137bba4bcbc64a4ec1da4d6d8dc5199"}, @@ -3138,7 +3138,7 @@ description = "Tornado IOLoop Backed Concurrent Futures" optional = true python-versions = "*" groups = ["main"] -markers = "extra == \"opentracing\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"opentracing\"" files = [ {file = "threadloop-1.0.2-py2-none-any.whl", hash = "sha256:5c90dbefab6ffbdba26afb4829d2a9df8275d13ac7dc58dccb0e279992679599"}, {file = "threadloop-1.0.2.tar.gz", hash = "sha256:8b180aac31013de13c2ad5c834819771992d350267bddb854613ae77ef571944"}, @@ -3154,7 +3154,7 @@ description = "Python bindings for the Apache Thrift RPC system" optional = true python-versions = "*" groups = ["main"] -markers = "extra == \"opentracing\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"opentracing\"" files = [ {file = "thrift-0.22.0.tar.gz", hash = "sha256:42e8276afbd5f54fe1d364858b6877bc5e5a4a5ed69f6a005b94ca4918fe1466"}, ] @@ -3220,7 +3220,6 @@ files = [ {file = "tomli-2.4.0-py3-none-any.whl", hash = "sha256:1f776e7d669ebceb01dee46484485f43a4048746235e683bcdffacdf1fb4785a"}, {file = "tomli-2.4.0.tar.gz", hash = "sha256:aa89c3f6c277dd275d8e243ad24f3b5e701491a860d5121f2cdd399fbb31fc9c"}, ] -markers = {main = "python_version < \"3.14\""} [[package]] name = "tornado" @@ -3229,7 +3228,7 @@ description = "Tornado is a Python web framework and asynchronous networking lib optional = true python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"opentracing\" or extra == \"all\"" +markers = "extra == \"all\" or extra == 
\"opentracing\"" files = [ {file = "tornado-6.5.5-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:487dc9cc380e29f58c7ab88f9e27cdeef04b2140862e5076a66fb6bb68bb1bfa"}, {file = "tornado-6.5.5-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:65a7f1d46d4bb41df1ac99f5fcb685fb25c7e61613742d5108b010975a9a6521"}, @@ -3361,7 +3360,7 @@ description = "non-blocking redis client for python" optional = true python-versions = "*" groups = ["main"] -markers = "extra == \"redis\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"redis\"" files = [ {file = "txredisapi-1.4.11-py3-none-any.whl", hash = "sha256:ac64d7a9342b58edca13ef267d4fa7637c1aa63f8595e066801c1e8b56b22d0b"}, {file = "txredisapi-1.4.11.tar.gz", hash = "sha256:3eb1af99aefdefb59eb877b1dd08861efad60915e30ad5bf3d5bf6c5cedcdbc6"}, @@ -3622,7 +3621,7 @@ description = "An XML Schema validator and decoder" optional = true python-versions = ">=3.7" groups = ["main"] -markers = "extra == \"saml2\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"saml2\"" files = [ {file = "xmlschema-2.5.1-py3-none-any.whl", hash = "sha256:ec2b2a15c8896c1fcd14dcee34ca30032b99456c3c43ce793fdb9dca2fb4b869"}, {file = "xmlschema-2.5.1.tar.gz", hash = "sha256:4f7497de6c8b6dc2c28ad7b9ed6e21d186f4afe248a5bea4f54eedab4da44083"}, @@ -3643,7 +3642,7 @@ description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "python_version < \"3.12\" and platform_machine != \"ppc64le\" and platform_machine != \"s390x\"" +markers = "platform_machine != \"ppc64le\" and platform_machine != \"s390x\" and python_version < \"3.12\"" files = [ {file = "zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e"}, {file = "zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166"}, @@ -3756,4 +3755,4 @@ url-preview = ["lxml"] [metadata] lock-version = "2.1" 
python-versions = ">=3.10.0,<4.0.0" -content-hash = "ef0540b89c417a69668f551688bd0974256ea7a580044f3954a76bdf0d8fe7c9" +content-hash = "8d994f1fc65664b2a04e1de78df4d1f06f3d99b39f95db16763790f2ee0aff11" diff --git a/pyproject.toml b/pyproject.toml index 6cb7c4de9e..44b8296438 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,7 +137,7 @@ saml2 = [ "defusedxml>=0.7.1", # via pysaml2 "pytz>=2018.3", # via pysaml2 ] -oidc = ["authlib>=0.15.1"] +oidc = ["authlib>=1.6.11"] url-preview = ["lxml>=4.6.3"] sentry = ["sentry-sdk>=0.7.2"] opentracing = [ @@ -179,7 +179,7 @@ all = [ # saml2 "pysaml2>=4.5.0", # oidc and jwt - "authlib>=0.15.1", + "authlib>=1.6.11", # url-preview "lxml>=4.6.3", # sentry diff --git a/rust/src/events/internal_metadata.rs b/rust/src/events/internal_metadata.rs index 21d3b8c435..6fd3d06b00 100644 --- a/rust/src/events/internal_metadata.rs +++ b/rust/src/events/internal_metadata.rs @@ -65,6 +65,7 @@ enum EventInternalMetadataData { DelayId(Box), TokenId(i64), DeviceId(Box), + CalculatedAuthEventIDs(Vec), // MSC4242: State DAGs } impl EventInternalMetadataData { @@ -140,6 +141,10 @@ impl EventInternalMetadataData { pyo3::intern!(py, "device_id"), o.into_pyobject(py).unwrap_infallible().into_any(), ), + EventInternalMetadataData::CalculatedAuthEventIDs(o) => ( + pyo3::intern!(py, "calculated_auth_event_ids"), + o.into_pyobject(py).unwrap().into_any(), + ), } } @@ -218,6 +223,11 @@ impl EventInternalMetadataData { .map(String::into_boxed_str) .with_context(|| format!("'{key_str}' has invalid type"))?, ), + "calculated_auth_event_ids" => EventInternalMetadataData::CalculatedAuthEventIDs( + value + .extract() + .with_context(|| format!("'{key_str}' has invalid type"))?, + ), _ => return Ok(None), }; @@ -395,6 +405,10 @@ impl EventInternalMetadataInner { get_property_opt!(self, DelayId).map(|s| s.deref()) } + pub fn get_calculated_auth_event_ids(&self) -> Option<&Vec> { + get_property_opt!(self, CalculatedAuthEventIDs) + } + pub fn get_token_id(&self) 
-> Option { get_property_opt!(self, TokenId).copied() } @@ -456,6 +470,10 @@ impl EventInternalMetadataInner { pub fn set_device_id(&mut self, obj: String) { set_property!(self, DeviceId, obj.into_boxed_str()); } + + pub fn set_calculated_auth_event_ids(&mut self, obj: Vec) { + set_property!(self, CalculatedAuthEventIDs, obj); + } } #[pyclass(frozen)] @@ -722,6 +740,21 @@ impl EventInternalMetadata { Ok(()) } + /// The calculated auth event IDs, if it was set when the event was created. + #[getter] + fn get_calculated_auth_event_ids(&self) -> PyResult> { + let guard = self.read_inner()?; + attr_err( + guard.get_calculated_auth_event_ids().cloned(), + "calculated_auth_event_ids", + ) + } + #[setter] + fn set_calculated_auth_event_ids(&self, obj: Vec) -> PyResult<()> { + self.write_inner()?.set_calculated_auth_event_ids(obj); + Ok(()) + } + /// The delay ID, set only if the event was a delayed event. #[getter] fn get_delay_id(&self) -> PyResult { diff --git a/rust/src/room_versions.rs b/rust/src/room_versions.rs index fbcc32516a..dbc962174d 100644 --- a/rust/src/room_versions.rs +++ b/rust/src/room_versions.rs @@ -47,6 +47,9 @@ impl EventFormatVersions { /// MSC4291 room IDs as hashes: introduced for room HydraV11 #[classattr] const ROOM_V11_HYDRA_PLUS: i32 = 4; + /// MSC4242 state DAGs: adds prev_state_events, removes auth_events + #[classattr] + const ROOM_VMSC4242: i32 = 5; } /// Enum to identify the state resolution algorithms. @@ -146,6 +149,14 @@ pub struct RoomVersion { /// /// In these room versions, we are stricter with event size validation. pub strict_event_byte_limits_room_versions: bool, + /// MSC4242: State DAGs. Creates events with prev_state_events instead of auth_events and derives + /// state from it. Events are always processed in causal order without any gaps in the DAG + /// (prev_state_events are always known), guaranteeing that processed events have a path to the + /// create event. 
This is an emergent property of state DAGs as asserting that there is a path + /// to the create event every time we insert an event would be prohibitively expensive. + /// This is similar to how doubly-linked lists can potentially not refer to previous items correctly + /// without verifying the list's integrity, but doing it on every insert is too expensive. + pub msc4242_state_dags: bool, } const ROOM_VERSION_V1: RoomVersion = RoomVersion { @@ -170,6 +181,7 @@ const ROOM_VERSION_V1: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V2: RoomVersion = RoomVersion { @@ -194,6 +206,7 @@ const ROOM_VERSION_V2: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V3: RoomVersion = RoomVersion { @@ -218,6 +231,7 @@ const ROOM_VERSION_V3: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V4: RoomVersion = RoomVersion { @@ -242,6 +256,7 @@ const ROOM_VERSION_V4: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V5: RoomVersion = RoomVersion { @@ -266,6 +281,7 @@ const ROOM_VERSION_V5: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V6: RoomVersion = RoomVersion { @@ -290,6 +306,7 @@ const ROOM_VERSION_V6: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, 
strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V7: RoomVersion = RoomVersion { @@ -314,6 +331,7 @@ const ROOM_VERSION_V7: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V8: RoomVersion = RoomVersion { @@ -338,6 +356,7 @@ const ROOM_VERSION_V8: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V9: RoomVersion = RoomVersion { @@ -362,6 +381,7 @@ const ROOM_VERSION_V9: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V10: RoomVersion = RoomVersion { @@ -386,6 +406,7 @@ const ROOM_VERSION_V10: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; /// MSC3389 (Redaction changes for events with a relation) based on room version "10". @@ -411,6 +432,7 @@ const ROOM_VERSION_MSC3389V10: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: true, + msc4242_state_dags: false, }; /// MSC1767 (Extensible Events) based on room version "10". @@ -436,6 +458,7 @@ const ROOM_VERSION_MSC1767V10: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; /// MSC3757 (Restricting who can overwrite a state event) based on room version "10". 
@@ -461,6 +484,7 @@ const ROOM_VERSION_MSC3757V10: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V11: RoomVersion = RoomVersion { @@ -485,6 +509,7 @@ const ROOM_VERSION_V11: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: true, // Changed from v10 + msc4242_state_dags: false, }; /// MSC3757 (Restricting who can overwrite a state event) based on room version "11". @@ -510,6 +535,7 @@ const ROOM_VERSION_MSC3757V11: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: true, + msc4242_state_dags: false, }; const ROOM_VERSION_HYDRA_V11: RoomVersion = RoomVersion { @@ -534,6 +560,7 @@ const ROOM_VERSION_HYDRA_V11: RoomVersion = RoomVersion { msc4289_creator_power_enabled: true, // Changed from v11 msc4291_room_ids_as_hashes: true, // Changed from v11 strict_event_byte_limits_room_versions: true, + msc4242_state_dags: false, }; const ROOM_VERSION_V12: RoomVersion = RoomVersion { @@ -558,6 +585,32 @@ const ROOM_VERSION_V12: RoomVersion = RoomVersion { msc4289_creator_power_enabled: true, // Changed from v11 msc4291_room_ids_as_hashes: true, // Changed from v11 strict_event_byte_limits_room_versions: true, + msc4242_state_dags: false, +}; + +const ROOM_VERSION_MSC4242V12: RoomVersion = RoomVersion { + identifier: "org.matrix.msc4242.12", + disposition: RoomDisposition::UNSTABLE, + event_format: EventFormatVersions::ROOM_VMSC4242, + state_res: StateResolutionVersions::V2_1, + enforce_key_validity: true, + special_case_aliases_auth: false, + strict_canonicaljson: true, + limit_notifications_power_levels: true, + implicit_room_creator: true, + updated_redaction_rules: true, + restricted_join_rule: true, + restricted_join_rule_fix: true, + 
knock_join_rule: true, + msc3389_relation_redactions: false, + knock_restricted_join_rule: true, + enforce_int_power_levels: true, + msc3931_push_features: &[], + msc3757_enabled: false, + msc4289_creator_power_enabled: true, + msc4291_room_ids_as_hashes: true, + strict_event_byte_limits_room_versions: true, + msc4242_state_dags: true, }; /// Helper class for managing the known room versions, and providing dict-like @@ -800,6 +853,10 @@ impl RoomVersions { fn V12(py: Python<'_>) -> PyResult> { ROOM_VERSION_V12.into_py_any(py) } + #[classattr] + fn MSC4242v12(py: Python<'_>) -> PyResult> { + ROOM_VERSION_MSC4242V12.into_py_any(py) + } } /// Called when registering modules with python. @@ -814,11 +871,12 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> child_module.add_class::()?; // Build KNOWN_EVENT_FORMAT_VERSIONS as a frozenset - let known_ef: [i32; 4] = [ + let known_ef: [i32; 5] = [ EventFormatVersions::ROOM_V1_V2, EventFormatVersions::ROOM_V3, EventFormatVersions::ROOM_V4_PLUS, EventFormatVersions::ROOM_V11_HYDRA_PLUS, + EventFormatVersions::ROOM_VMSC4242, ]; let known_event_format_versions = PyFrozenSet::new(py, known_ef)?; child_module.add("KNOWN_EVENT_FORMAT_VERSIONS", known_event_format_versions)?; diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 702c7e3246..f1a7771568 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -479,6 +479,12 @@ class ExperimentalConfig(Config): # Enable room version (and thus applicable push rules from MSC3931/3932) KNOWN_ROOM_VERSIONS.add_room_version(RoomVersions.MSC1767v10) + # MSC4242: State DAGs + self.msc4242_enabled: bool = experimental.get("msc4242_enabled", False) + if self.msc4242_enabled: + # Enable the room version + KNOWN_ROOM_VERSIONS.add_room_version(RoomVersions.MSC4242v12) + # MSC3391: Removing account data. 
self.msc3391_enabled = experimental.get("msc3391_enabled", False) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index bf239e660d..ca528ae235 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -61,7 +61,7 @@ from synapse.api.room_versions import ( EventFormatVersions, RoomVersion, ) -from synapse.events import is_creator +from synapse.events import FrozenEventVMSC4242, is_creator from synapse.state import CREATE_KEY from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( @@ -186,6 +186,70 @@ async def check_state_independent_auth_rules( # 1.5 Otherwise, allow return + # State DAGs 2. Considering the event's prev_state_events: + if event.room_version.msc4242_state_dags: + prev_state_events_ids = set(cast(FrozenEventVMSC4242, event).prev_state_events) + # Fetch all of the `prev_state_events` + prev_state_events = {} + # Try to load the `prev_state_events` from `batched_auth_events` initially as + # that can save us a database hit. 
+ if batched_auth_events is not None: + prev_state_events = { + event_id: value + for event_id in prev_state_events_ids + if (value := batched_auth_events.get(event_id)) is not None + } + # Fetch the rest of the `prev_state_events` + missing_prev_state_events_ids = prev_state_events_ids - set( + prev_state_events.keys() + ) + fetched_prev_state_events = await store.get_events( + missing_prev_state_events_ids, + redact_behaviour=EventRedactBehaviour.as_is, + allow_rejected=True, + ) + prev_state_events.update(fetched_prev_state_events) + if len(prev_state_events) != len(prev_state_events_ids): + # we should have all the `prev_state_events` by now, so if we do not, that suggests + # a Synapse programming error + known_prev_state_event_ids = set(prev_state_events) + raise AssertionError( + f"Event {event.event_id} has unknown prev_state_events " + + f"({len(prev_state_events)}/{len(prev_state_events_ids)} known)" + + f"{prev_state_events_ids - known_prev_state_event_ids} missing " + + f"out of {prev_state_events_ids}" + ) + for prev_state_event in prev_state_events.values(): + # 2.1 If there are entries which do not belong in the same room, reject. + if prev_state_event.room_id != event.room_id: + raise AuthError( + 403, + "During auth for event %s in room %s, found event %s in prev_state_events " + "which belongs to a different room %s" + % ( + event.event_id, + event.room_id, + prev_state_event.event_id, + prev_state_event.room_id, + ), + ) + # 2.2 If there are entries which do not have a state_key, reject. + if not prev_state_event.is_state(): + raise AuthError( + 403, + f"During auth for event {event.event_id} in room {event.room_id}, event has a " + + f"prev_state_event which is not state: {prev_state_event.event_id}", + ) + # 2.3 If there are entries which were themselves rejected under the checks performed on + # receipt of a PDU, reject. 
+ if prev_state_event.rejected_reason is not None: + raise AuthError( + 403, + f"During auth for event {event.event_id} in room {event.room_id}, event has a " + + f"prev_state_event which is rejected ({prev_state_event.rejected_reason}): " + + f"{prev_state_event.event_id}", + ) + # 2. Reject if event has auth_events that: ... auth_events: ChainMap[str, EventBase] = ChainMap() if batched_auth_events: @@ -450,6 +514,12 @@ def _check_create(event: "EventBase") -> None: if event.prev_event_ids(): raise AuthError(403, "Create event has prev events") + # State DAGs 1.2 If it has any prev_state_events, reject. + if event.room_version.msc4242_state_dags: + assert isinstance(event, FrozenEventVMSC4242) + if len(event.prev_state_events) > 0: + raise AuthError(403, "Create event has prev state events") + if event.room_version.msc4291_room_ids_as_hashes: # 1.2 If the create event has a room_id, reject if "room_id" in event: diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index f48d5c4f1d..3c46d02e92 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -34,6 +34,7 @@ from typing import ( ) import attr +from typing_extensions import deprecated from unpaddedbase64 import encode_base64 from synapse.api.constants import ( @@ -44,10 +45,7 @@ from synapse.api.constants import ( ) from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.synapse_rust.events import EventInternalMetadata -from synapse.types import ( - JsonDict, - StrCollection, -) +from synapse.types import JsonDict, StateKey, StrCollection from synapse.util.caches import intern_dict from synapse.util.duration import Duration from synapse.util.frozenutils import freeze @@ -222,6 +220,9 @@ class EventBase(metaclass=abc.ABCMeta): state_key: DictProperty[str] = DictProperty("state_key") type: DictProperty[str] = DictProperty("type") + # This is a deprecated property, use `sender` instead. Only used by modules. 
+ user_id: DictProperty[str] = DictProperty("sender") + @property def event_id(self) -> str: raise NotImplementedError() @@ -363,6 +364,11 @@ class EventBase(metaclass=abc.ABCMeta): ">" ) + # Using `__getitem__` is deprecated. Only used by modules. + @deprecated("Use attribute access instead") + def __getitem__(self, field: str) -> Any | None: + return self._dict[field] + class FrozenEvent(EventBase): format_version = EventFormatVersions.ROOM_V1_V2 # All events of this type are V1 @@ -575,9 +581,60 @@ class FrozenEventV4(FrozenEventV3): return [*self._dict["auth_events"], create_event_id] +class FrozenEventVMSC4242(FrozenEventV4): + """FrozenEventVMSC4242, which differs from FrozenEventV4 only in the addition of prev_state_events""" + + format_version = EventFormatVersions.ROOM_VMSC4242 + prev_state_events: DictProperty[list[str]] = DictProperty("prev_state_events") + + def __init__( + self, + event_dict: JsonDict, + room_version: RoomVersion, + internal_metadata_dict: JsonDict | None = None, + rejected_reason: str | None = None, + ): + # Similar to how we assert event_id isn't in V2+ events, we do the same with auth_events. + # We don't expect `auth_events` in the wire format because we calculate it from prev_state_events. + assert "auth_events" not in event_dict + super().__init__( + event_dict=event_dict, + room_version=room_version, + internal_metadata_dict=internal_metadata_dict, + rejected_reason=rejected_reason, + ) + + def auth_event_ids(self) -> StrCollection: + """Returns the list of _calculated_ auth event IDs. + + Returns: + The list of event IDs of this event's auth events + """ + # Catches cases where we accidentally call auth_event_ids() prior to calculating what they + # actually are. The exception being the m.room.create event which has no auth events. 
+ if self.type != EventTypes.Create: + assert len(self.internal_metadata.calculated_auth_event_ids) > 0 + return self.internal_metadata.calculated_auth_event_ids + + def __repr__(self) -> str: + rejection = f"REJECTED={self.rejected_reason}, " if self.rejected_reason else "" + + return ( + f"<{self.__class__.__name__} " + f"{rejection}" + f"event_id={self.event_id}, " + f"type={self.get('type')}, " + f"state_key={self.get('state_key')}, " + f"prev_events={self.get('prev_events')}, " + f"prev_state_events={self.get('prev_state_events')}, " + f"outlier={self.internal_metadata.is_outlier()}" + ">" + ) + + def _event_type_from_format_version( format_version: int, -) -> type[FrozenEvent | FrozenEventV2 | FrozenEventV3]: +) -> type[FrozenEvent | FrozenEventV2 | FrozenEventV3 | FrozenEventVMSC4242]: """Returns the python type to use to construct an Event object for the given event format version. @@ -594,6 +651,8 @@ def _event_type_from_format_version( return FrozenEventV2 elif format_version == EventFormatVersions.ROOM_V4_PLUS: return FrozenEventV3 + elif format_version == EventFormatVersions.ROOM_VMSC4242: + return FrozenEventVMSC4242 elif format_version == EventFormatVersions.ROOM_V11_HYDRA_PLUS: return FrozenEventV4 else: @@ -655,6 +714,24 @@ def relation_from_event(event: EventBase) -> _EventRelation | None: return _EventRelation(parent_id, rel_type, aggregation_key) +def event_exists_in_state_dag( + event: Union["EventBase", "EventBuilder", "EventMetadata", "StateKey"], +) -> bool: + """Given an event, returns true if this event should form part of the state DAG. 
+ Only valid for room versions which use a state DAG (MSC4242).""" + state_key = None + if isinstance(event, EventMetadata): + state_key = event.state_key + elif isinstance(event, tuple): # StateKey + # can't use StateKey else you get: + # "Subscripted generics cannot be used with class and instance checks" + state_key = event[1] + else: + state_key = event.state_key if event.is_state() else None + + return state_key is not None + + def is_creator(create: EventBase, user_id: str) -> bool: """ Return true if the provided user ID is the room creator. @@ -689,3 +766,13 @@ class StrippedStateEvent: state_key: str sender: str content: dict[str, Any] + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventMetadata: + """Returned by `get_metadata_for_events`""" + + room_id: str + event_type: str + state_key: str | None + rejection_reason: str | None diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 2cd1bf6106..78eb98e1e5 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -132,6 +132,7 @@ class EventBuilder: prev_event_ids: list[str], auth_event_ids: list[str] | None, depth: int | None = None, + prev_state_events: list[str] | None = None, ) -> EventBase: """Transform into a fully signed and hashed event @@ -143,10 +144,51 @@ class EventBuilder: depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. - + prev_state_events: The event IDs to use as prev_state_events. + Only applicable on MSC4242 state DAG rooms. If this is supplied, auth_event_ids + must not be specified unless this event is part of a batch such that the builder + will be unable to compute the auth_event_ids due to the events not being persisted + yet. Returns: The signed and hashed event. """ + # If the caller specifies this, make sure the room version supports it. 
+ if prev_state_events: + assert self.room_version.msc4242_state_dags + if self.room_version.msc4242_state_dags: + assert prev_state_events is not None + if self.room_id: + state_ids = await self._state.compute_state_after_events( + self.room_id, + prev_state_events, + state_filter=StateFilter.from_types( + auth_types_for_event(self.room_version, self) + ), + await_full_state=False, + ) + # When we create rooms we only insert the create+member events, and batch the rest. + # Therefore, we may not have state_ids from compute_state_after_events as the + # prev_state_events are unknown. If this happens, the caller provides the auth events + # to use instead. + calculated_auth_event_ids: list[ + str + ] = [] # assume it's the create event which has [] + if len(state_ids) == 0 and len(prev_state_events) > 0: + # it's a batched event, so we should have been provided the auth_events + assert auth_event_ids and len(auth_event_ids) > 0 + calculated_auth_event_ids = auth_event_ids + else: + calculated_auth_event_ids = ( + self._event_auth_handler.compute_auth_events(self, state_ids) + ) + else: + # event is a state DAG event and is the create event (room_id is not provided), + # therefore there are no auth_events. + calculated_auth_event_ids = [] + assert self.type == EventTypes.Create and self.state_key == "" + self.internal_metadata.calculated_auth_event_ids = calculated_auth_event_ids + auth_event_ids = calculated_auth_event_ids + # Create events always have empty auth_events. 
if self.type == EventTypes.Create and self.is_state() and self.state_key == "": auth_event_ids = [] @@ -155,6 +197,8 @@ class EventBuilder: if auth_event_ids is None: # Every non-create event must have a room ID assert self.room_id is not None + # this block must not be hit for MSC4242 rooms as it resolves state with prev_events + assert not self.room_version.msc4242_state_dags state_ids = await self._state.compute_state_after_events( self.room_id, prev_event_ids, @@ -231,7 +275,6 @@ class EventBuilder: # rejected by other servers (and so that they can be persisted in # the db) depth = min(depth, MAX_DEPTH) - event_dict: dict[str, Any] = { "auth_events": auth_events, "prev_events": prev_events, @@ -241,8 +284,6 @@ class EventBuilder: "unsigned": self.unsigned, "depth": depth, } - if self.room_id is not None: - event_dict["room_id"] = self.room_id if self.room_version.msc4291_room_ids_as_hashes: # In MSC4291: the create event has no room ID as the create event ID /is/ the room ID. @@ -262,6 +303,14 @@ class EventBuilder: auth_event_ids.remove(create_event_id) event_dict["auth_events"] = auth_event_ids + if self.room_version.msc4242_state_dags: + # Auth events are removed entirely on state DAG rooms + event_dict.pop("auth_events") + assert prev_state_events is not None + event_dict["prev_state_events"] = prev_state_events + if self.room_id is not None: + event_dict["room_id"] = self.room_id + if self.is_state(): event_dict["state_key"] = self._state_key diff --git a/synapse/events/utils.py b/synapse/events/utils.py index ff0476f5fb..f038fb5578 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -156,6 +156,10 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic # Earlier room versions from had additional allowed keys. 
if not room_version.updated_redaction_rules: allowed_keys.extend(["prev_state", "membership", "origin"]) + # Custom room versions add new allowed keys and remove others + if room_version.msc4242_state_dags: + allowed_keys.extend(["prev_state_events"]) + allowed_keys.remove("auth_events") event_type = event_dict["type"] diff --git a/synapse/events/validator.py b/synapse/events/validator.py index b27f8a942a..ff22b2287f 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -63,14 +63,17 @@ class EventValidator: if event.format_version == EventFormatVersions.ROOM_V1_V2: EventID.from_string(event.event_id) - required = [ + required = { "auth_events", "content", "hashes", "prev_events", "sender", "type", - ] + } + if event.room_version.msc4242_state_dags: + required.remove("auth_events") + required.add("prev_state_events") for k in required: if k not in event: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 55151ca549..78a1900c73 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1108,6 +1108,11 @@ class FederationClient(FederationBase): SynapseError: if the chosen remote server returns a 300/400 code, or no servers successfully handle the request. 
""" + # See related restriction in /createRoom requests in handlers/room.py + if room_version.msc4242_state_dags: + raise UnsupportedRoomVersionError( + "Homeserver does not support this room version over federation" + ) async def send_request(destination: str) -> SendJoinResult: response = await self._do_send_join( diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index d2c1f98d7c..51a752472f 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -32,7 +32,7 @@ import attr from synapse.api.constants import Direction, EventTypes, Membership from synapse.api.errors import SynapseError -from synapse.events import EventBase +from synapse.events import EventBase, FrozenEventVMSC4242 from synapse.events.utils import FilteredEvent from synapse.types import ( JsonMapping, @@ -494,9 +494,16 @@ class AdminHandler: event_dict["redacts"] = event.event_id try: + prev_state_events = None + if room_version.msc4242_state_dags: + assert isinstance(event, FrozenEventVMSC4242) + prev_state_events = event.prev_state_events + assert prev_state_events is not None, ( + "Parent event of redaction has no `prev_state_events` which should be impossible as `prev_state_events` is a required field in MSC4242 rooms" + ) # set the prev event to the offending message to allow for redactions # to be processed in the case where the user has been kicked/banned before - # redactions are requested + # redactions are requested. 
( redaction, _, @@ -505,6 +512,7 @@ class AdminHandler: event_dict, prev_event_ids=[event.event_id], ratelimit=False, + prev_state_events=prev_state_events, ) except Exception as ex: logger.info( diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 9a371651fb..2225466648 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -58,6 +58,7 @@ from synapse.types import ( DeviceListUpdates, JsonDict, JsonMapping, + MultiWriterStreamToken, ScheduledTask, StrCollection, StreamKeyType, @@ -1193,7 +1194,16 @@ class DeviceWriterHandler(DeviceHandler): changes = await self.store.get_device_list_changes_in_room( room_id, device_lists_stream_id ) - local_changes = {(u, d) for u, d in changes if self.hs.is_mine_id(u)} + if changes is not None: + local_changes = {(u, d) for u, d in changes if self.hs.is_mine_id(u)} + else: + # The `device_lists_stream_id` is too old, so we need to fall back + # to looking for changes for all local users. + local_users = await self.store.get_local_users_in_room(room_id) + local_changes = await self.store.get_device_changes_for_users( + MultiWriterStreamToken(stream=device_lists_stream_id), local_users + ) + if not local_changes: return diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0aa0a16127..4032c7eca9 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -53,7 +53,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.urls import ConsentURIBuilder from synapse.event_auth import validate_event_for_room_version -from synapse.events import EventBase, relation_from_event +from synapse.events import EventBase, FrozenEventVMSC4242, relation_from_event from synapse.events.builder import EventBuilder from synapse.events.snapshot import ( EventContext, @@ -589,6 +589,7 @@ class EventCreationHandler: state_map: StateMap[str] | None = None, for_batch: bool = False, current_state_group: int | None = 
None, + prev_state_events: list[str] | None = None, delay_id: str | None = None, ) -> tuple[EventBase, UnpersistedEventContextBase]: """ @@ -644,6 +645,10 @@ class EventCreationHandler: current_state_group: the current state group, used only for creating events for batch persisting + prev_state_events: + The state event IDs which represent the current forward extremities of the state DAG. + Only applicable on room versions which use a state DAG (MSC4242). + delay_id: The delay ID of this event, if it was a delayed event. Raises: @@ -748,6 +753,7 @@ class EventCreationHandler: state_map=state_map, for_batch=for_batch, current_state_group=current_state_group, + prev_state_events=prev_state_events, ) # In an ideal world we wouldn't need the second part of this condition. However, @@ -976,6 +982,7 @@ class EventCreationHandler: ignore_shadow_ban: bool = False, outlier: bool = False, depth: int | None = None, + prev_state_events: list[str] | None = None, delay_id: str | None = None, ) -> tuple[EventBase, int]: """ @@ -1005,6 +1012,9 @@ class EventCreationHandler: depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. + prev_state_events: + The state event IDs which represent the current forward extremities of the state DAG. + Only applicable on room versions which use a state DAG (MSC4242). delay_id: The delay ID of this event, if it was a delayed event. 
Returns: @@ -1102,6 +1112,7 @@ class EventCreationHandler: ignore_shadow_ban=ignore_shadow_ban, outlier=outlier, depth=depth, + prev_state_events=prev_state_events, delay_id=delay_id, ) @@ -1116,6 +1127,7 @@ class EventCreationHandler: ignore_shadow_ban: bool = False, outlier: bool = False, depth: int | None = None, + prev_state_events: list[str] | None = None, delay_id: str | None = None, ) -> tuple[EventBase, int]: room_id = event_dict["room_id"] @@ -1145,6 +1157,7 @@ class EventCreationHandler: state_event_ids=state_event_ids, outlier=outlier, depth=depth, + prev_state_events=prev_state_events, delay_id=delay_id, ) context = await unpersisted_context.persist(event) @@ -1240,6 +1253,7 @@ class EventCreationHandler: state_map: StateMap[str] | None = None, for_batch: bool = False, current_state_group: int | None = None, + prev_state_events: list[str] | None = None, ) -> tuple[EventBase, UnpersistedEventContextBase]: """Create a new event for a local client. If bool for_batch is true, will create an event using the prev_event_ids, and will create an event context for @@ -1281,9 +1295,30 @@ class EventCreationHandler: current_state_group: the current state group, used only for creating events for batch persisting + prev_state_events: + The state event IDs which represent the current forward extremities of the state DAG. + Only applicable on room versions which use a state DAG (MSC4242). + If unset, populates them from the current state dag forward extremities. + Returns: Tuple of created event, UnpersistedEventContext """ + if builder.room_version.msc4242_state_dags: + assert auth_event_ids is None + # (kegan) I can't find any call-site which uses this. We can't risk letting in + # untrusted input, so for now assert that we aren't told about any state. 
+ assert state_event_ids is None + + if builder.room_id: + if prev_state_events is None: + prev_state_events = list( + await self.store.get_state_dag_extremities(builder.room_id) + ) + else: + # create event doesn't need prev_state_events to be fetched, but it must be non-None. + assert builder.type == EventTypes.Create and builder.state_key == "" + prev_state_events = [] + # Strip down the state_event_ids to only what we need to auth the event. # For example, we don't need extra m.room.member that don't match event.sender if state_event_ids is not None: @@ -1357,7 +1392,10 @@ class EventCreationHandler: assert state_map is not None auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map) event = await builder.build( - prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth + prev_event_ids=prev_event_ids, + auth_event_ids=auth_ids, + depth=depth, + prev_state_events=prev_state_events, ) context: UnpersistedEventContextBase = ( @@ -1374,6 +1412,7 @@ class EventCreationHandler: prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, depth=depth, + prev_state_events=prev_state_events, ) # Pass on the outlier property from the builder to the event @@ -1563,6 +1602,20 @@ class EventCreationHandler: auth_event = event_id_to_event.get(event_id) if auth_event: batched_auth_events[event_id] = auth_event + if event.room_version.msc4242_state_dags: + assert isinstance(event, FrozenEventVMSC4242) + # State DAG rooms will check that the prev_state_events are not rejected. + # To do that, we need to make sure we pass in the prev_state_events as + # batched_auth_events, else we will fail the event due to the + # prev_state_events not existing in the database. 
+ for prev_state_event_id in event.prev_state_events: + prev_state_event = event_id_to_event.get( + prev_state_event_id + ) + if prev_state_event: + batched_auth_events[prev_state_event_id] = ( + prev_state_event + ) await self._event_auth_handler.check_auth_rules_from_context( event, batched_auth_events ) @@ -1817,7 +1870,10 @@ class EventCreationHandler: # set for a while, so that the expiry time is reset. state_entry = await self.state.resolve_state_groups_for_events( - event.room_id, event_ids=event.prev_event_ids() + event.room_id, + event_ids=event.prev_state_events + if isinstance(event, FrozenEventVMSC4242) + else event.prev_event_ids(), ) if state_entry.state_group: @@ -2360,9 +2416,16 @@ class EventCreationHandler: # case. prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) + prev_state_events = None + if original_event.room_version.msc4242_state_dags: + prev_state_events = list( + await self.store.get_state_dag_extremities(builder.room_id) + ) + event = await builder.build( prev_event_ids=prev_event_ids, auth_event_ids=None, + prev_state_events=prev_state_events, ) # we rebuild the event context, to be on the safe side. 
If nothing else, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 9074d7916b..f110be0a2f 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -65,7 +65,7 @@ from synapse.api.filtering import Filter from synapse.api.ratelimiting import Ratelimiter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.event_auth import validate_event_for_room_version -from synapse.events import EventBase +from synapse.events import EventBase, event_exists_in_state_dag from synapse.events.snapshot import UnpersistedEventContext from synapse.events.utils import FilteredEvent, copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations @@ -1237,6 +1237,10 @@ class RoomCreationHandler: creation_content = config.get("creation_content", {}) # override any attempt to set room versions via the creation_content creation_content["room_version"] = room_version.identifier + # We do not currently support federating state DAG rooms. + # See related restriction in /send_join requests in federation_client.py. + if room_version.msc4242_state_dags: + creation_content[EventContentFields.FEDERATE] = False # trusted private chats have the invited users marked as additional creators if ( @@ -1486,6 +1490,11 @@ class RoomCreationHandler: # the most recently created event prev_event: list[str] = [] + # This should be the most recently created state event as we create each event + prev_state_events: list[str] | None = ( + [] if room_version.msc4242_state_dags else None + ) + # a map of event types, state keys -> event_ids. We collect these mappings this as events are # created (but not persisted to the db) to determine state for future created events # (as this info can't be pulled from the db) @@ -1512,6 +1521,7 @@ class RoomCreationHandler: """ nonlocal depth nonlocal prev_event + nonlocal prev_state_events # Create the event dictionary. 
event_dict = {"type": etype, "content": content} @@ -1525,6 +1535,7 @@ class RoomCreationHandler: creator, event_dict, prev_event_ids=prev_event, + prev_state_events=prev_state_events, depth=depth, # Take a copy to ensure each event gets a unique copy of # state_map since it is modified below. @@ -1535,7 +1546,8 @@ class RoomCreationHandler: depth += 1 prev_event = [new_event.event_id] state_map[(new_event.type, new_event.state_key)] = new_event.event_id - + if room_version.msc4242_state_dags and event_exists_in_state_dag(new_event): + prev_state_events = [new_event.event_id] return new_event, new_unpersisted_context preset_config, config = self._room_preset_config(room_config) @@ -1568,6 +1580,8 @@ class RoomCreationHandler: ignore_shadow_ban=True, ) last_sent_event_id = ev.event_id + if room_version.msc4242_state_dags: + prev_state_events = [ev.event_id] member_event_id, _ = await self.room_member_handler.update_membership( creator, @@ -1579,8 +1593,11 @@ class RoomCreationHandler: new_room=True, prev_event_ids=[last_sent_event_id], depth=depth, + prev_state_events=prev_state_events, ) prev_event = [member_event_id] + if room_version.msc4242_state_dags: + prev_state_events = [member_event_id] # update the depth and state map here as the membership event has been created # through a different code path diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index b2e678e90e..236c8ca03c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -36,9 +36,11 @@ from synapse.api.constants import ( from synapse.api.errors import ( AuthError, Codes, + NotFoundError, PartialStateConflictError, ShadowBanError, SynapseError, + UnsupportedRoomVersionError, ) from synapse.api.ratelimiting import Ratelimiter from synapse.event_auth import get_named_level, get_power_level_event @@ -408,6 +410,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): require_consent: bool = True, outlier: bool = False, origin_server_ts: int | None = 
None, + prev_state_events: list[str] | None = None, delay_id: str | None = None, ) -> tuple[str, int]: """ @@ -494,6 +497,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): depth=depth, require_consent=require_consent, outlier=outlier, + prev_state_events=prev_state_events, delay_id=delay_id, ) context = await unpersisted_context.persist(event) @@ -590,6 +594,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): state_event_ids: list[str] | None = None, depth: int | None = None, origin_server_ts: int | None = None, + prev_state_events: list[str] | None = None, delay_id: str | None = None, ) -> tuple[str, int]: """Update a user's membership in a room. @@ -684,6 +689,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): state_event_ids=state_event_ids, depth=depth, origin_server_ts=origin_server_ts, + prev_state_events=prev_state_events, delay_id=delay_id, ) @@ -707,6 +713,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): state_event_ids: list[str] | None = None, depth: int | None = None, origin_server_ts: int | None = None, + prev_state_events: list[str] | None = None, delay_id: str | None = None, ) -> tuple[str, int]: """Helper for update_membership. 
@@ -951,10 +958,21 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): require_consent=require_consent, outlier=outlier, origin_server_ts=origin_server_ts, + prev_state_events=prev_state_events, delay_id=delay_id, ) - latest_event_ids = await self.store.get_prev_events_for_room(room_id) + is_state_dags = False + try: + room_version = await self.store.get_room_version(room_id) + is_state_dags = room_version.msc4242_state_dags + except (NotFoundError, UnsupportedRoomVersionError): + pass + + if is_state_dags: + latest_event_ids = list(await self.store.get_state_dag_extremities(room_id)) + else: + latest_event_ids = await self.store.get_prev_events_for_room(room_id) is_partial_state_room = await self.store.is_partial_state_room(room_id) partial_state_before_join = await self.state_handler.compute_state_after_events( @@ -1165,6 +1183,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # see: https://github.com/matrix-org/synapse/issues/7139 if len(latest_event_ids) == 0: latest_event_ids = [invite.event_id] + if invite.room_version.msc4242_state_dags: + prev_state_events = [invite.event_id] # or perhaps this is a remote room that a local user has knocked on elif current_membership_type == Membership.KNOCK: @@ -1210,6 +1230,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): require_consent=require_consent, outlier=outlier, origin_server_ts=origin_server_ts, + prev_state_events=prev_state_events, delay_id=delay_id, ) @@ -2108,10 +2129,21 @@ class RoomMemberMasterHandler(RoomMemberHandler): # # the prev_events consist solely of the previous membership event. 
prev_event_ids = [previous_membership_event.event_id] - auth_event_ids = ( - list(previous_membership_event.auth_event_ids()) + prev_event_ids - ) + auth_event_ids = None + # Authorise the leave by referencing the previous membership + prev_state_event_ids = None + if previous_membership_event.room_version.msc4242_state_dags: + prev_state_event_ids = [ + previous_membership_event.event_id, + ] + else: + auth_event_ids = ( + list(previous_membership_event.auth_event_ids()) + prev_event_ids + ) + # State DAG rooms should not have auth events specified + # Normal rooms should not have prev state event IDs specified + assert not (prev_state_event_ids is not None and auth_event_ids is not None) # Try several times, it could fail with PartialStateConflictError # in handle_new_client_event, cf comment in except block. max_retries = 5 @@ -2127,6 +2159,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, outlier=True, + prev_state_events=prev_state_event_ids, ) context = await unpersisted_context.persist(event) event.internal_metadata.out_of_band_membership = True diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index b209404cd1..0774b6ed40 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -96,6 +96,10 @@ from synapse.rest.admin.statistics import ( LargestRoomsStatistics, UserMediaStatisticsRestServlet, ) +from synapse.rest.admin.user_reports import ( + UserReportDetailRestServlet, + UserReportsRestServlet, +) from synapse.rest.admin.username_available import UsernameAvailableRestServlet from synapse.rest.admin.users import ( AccountDataRestServlet, @@ -312,6 +316,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: LargestRoomsStatistics(hs).register(http_server) EventReportDetailRestServlet(hs).register(http_server) EventReportsRestServlet(hs).register(http_server) + UserReportsRestServlet(hs).register(http_server) + 
UserReportDetailRestServlet(hs).register(http_server) AccountDataRestServlet(hs).register(http_server) PushersRestServlet(hs).register(http_server) MakeRoomAdminRestServlet(hs).register(http_server) diff --git a/synapse/rest/admin/user_reports.py b/synapse/rest/admin/user_reports.py new file mode 100644 index 0000000000..119dc86517 --- /dev/null +++ b/synapse/rest/admin/user_reports.py @@ -0,0 +1,173 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2026 Element Creations, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + + +import logging +from http import HTTPStatus +from typing import TYPE_CHECKING + +from synapse.api.constants import Direction +from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string +from synapse.http.site import SynapseRequest +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class UserReportsRestServlet(RestServlet): + """ + List all reported users that are known to the homeserver. Results are returned + in a dictionary containing report information. Supports pagination. + The requester must have administrator access in Synapse. + + GET /_synapse/admin/v1/user_reports + returns: + 200 OK with list of reports if success otherwise an error. + + Args: + The parameters `from` and `limit` are required only for pagination. + By default, a `limit` of 100 is used. + The parameter `dir` can be used to define the order of results.
+ The `user_id` query parameter filters by the user ID of the reporter of the target user. + The `target_user_id` query parameter filters by the user ID of the target user. + Returns: + A list of user reports and an integer representing the total number of user + reports that exist given this query + """ + + PATTERNS = admin_patterns("/user_reports$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastores().main + + async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + direction = parse_enum(request, "dir", Direction, Direction.BACKWARDS) + user_id = parse_string(request, "user_id") + target_user_id = parse_string(request, "target_user_id") + + if start < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "The start parameter must be a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "The limit parameter must be a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + user_reports, total = await self._store.get_user_reports_paginate( + start, limit, direction, user_id, target_user_id + ) + ret = {"user_reports": user_reports, "total": total} + if (start + limit) < total: + ret["next_token"] = start + len(user_reports) + + return HTTPStatus.OK, ret + + +class UserReportDetailRestServlet(RestServlet): + """ + Get a specific user report that is known to the homeserver. Results are returned + in a dictionary containing report information. + The requester must have administrator access in Synapse. + + GET /_synapse/admin/v1/user_reports/<report_id> + returns: + 200 OK with details report if success otherwise an error. + + Args: + The parameter `report_id` is the ID of the user report in the database.
+ Returns: + JSON blob of information about the user report + """ + + PATTERNS = admin_patterns("/user_reports/(?P<report_id>[^/]*)$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastores().main + + async def on_GET( + self, request: SynapseRequest, report_id: str + ) -> tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + message = ( + "The report_id parameter must be a string representing a positive integer." + ) + try: + resolved_report_id = int(report_id) + except ValueError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) + + if resolved_report_id < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) + + ret = await self._store.get_user_report(resolved_report_id) + if not ret: + raise NotFoundError("User report not found") + + id, received_ts, target_user_id, user_id, reason = ret + response = { + "id": id, + "received_ts": received_ts, + "target_user_id": target_user_id, + "user_id": user_id, + "reason": reason, + } + + return HTTPStatus.OK, response + + async def on_DELETE( + self, request: SynapseRequest, report_id: str + ) -> tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + message = ( + "The report_id parameter must be a string representing a positive integer."
+ ) + try: + resolved_report_id = int(report_id) + except ValueError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) + + if resolved_report_id < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) + + if await self._store.delete_user_report(resolved_report_id): + return HTTPStatus.OK, {} + + raise NotFoundError("User report not found") diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py index 4db3b01576..15f58acb95 100644 --- a/synapse/rest/client/media.py +++ b/synapse/rest/client/media.py @@ -253,6 +253,7 @@ class DownloadResource(RestServlet): ), send_cors=True, ) + return set_cors_headers(request) set_corp_headers(request) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index a92233c863..2f0e3f2c3e 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -37,7 +37,7 @@ from prometheus_client import Counter, Histogram from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions -from synapse.events import EventBase +from synapse.events import EventBase, FrozenEventVMSC4242 from synapse.events.snapshot import ( EventContext, UnpersistedEventContext, @@ -239,31 +239,6 @@ class StateHandler: ) return await ret.get_state(self._state_storage_controller, state_filter) - async def get_current_user_ids_in_room( - self, room_id: str, latest_event_ids: StrCollection - ) -> set[str]: - """ - Get the users IDs who are currently in a room. - - Note: This is much slower than using the equivalent method - `DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`, - so this should only be used when wanting the users at a particular point - in the room. - - Args: - room_id: The ID of the room. - latest_event_ids: Precomputed list of latest event IDs. Will be computed if None. - Returns: - Set of user IDs in the room. 
- """ - - assert latest_event_ids is not None - - logger.debug("calling resolve_state_groups from get_current_user_ids_in_room") - entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - state = await entry.get_state(self._state_storage_controller, StateFilter.all()) - return await self.store.get_joined_user_ids_from_state(room_id, state) - async def get_hosts_in_room_at_events( self, room_id: str, event_ids: StrCollection ) -> frozenset[str]: @@ -303,7 +278,8 @@ class StateHandler: membership events. `False` if `state_ids_before_event` is the full state. `None` when `state_ids_before_event` is not provided. In this case, the - flag will be calculated based on `event`'s prev events. + flag will be calculated based on `event`'s `prev_events` or `prev_state_events` + for state DAG rooms. state_group_before_event: the current state group at the time of event, if known Returns: @@ -337,7 +313,11 @@ class StateHandler: # (This is slightly racy - the prev-events might get fixed up before we use # their states - but I don't think that really matters; it just means we # might redundantly recalculate the state for this event later.) 
- prev_event_ids = event.prev_event_ids() + prev_event_ids = frozenset( + event.prev_state_events + if isinstance(event, FrozenEventVMSC4242) + else event.prev_event_ids() + ) incomplete_prev_events = await self.store.get_partial_state_events( prev_event_ids ) @@ -355,7 +335,7 @@ class StateHandler: entry = await self.resolve_state_groups_for_events( event.room_id, - event.prev_event_ids(), + prev_event_ids, await_full_state=False, ) diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 2948227807..7cc6a39639 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -35,6 +35,7 @@ from typing import ( Generic, Iterable, TypeVar, + cast, ) import attr @@ -43,7 +44,9 @@ from prometheus_client import Counter, Histogram from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.events import EventBase +from synapse.api.errors import SynapseError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import EventBase, FrozenEventVMSC4242, event_exists_in_state_dag from synapse.events.snapshot import EventContext, EventPersistencePair from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable @@ -68,6 +71,7 @@ from synapse.types import ( from synapse.types.state import StateFilter from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results from synapse.util.metrics import Measure +from synapse.util.stringutils import shortstr if TYPE_CHECKING: from synapse.server import HomeServer @@ -111,6 +115,14 @@ stale_forward_extremities_counter = Histogram( buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), ) +# The number of forward extremities for each new event. 
+msc4242_state_dag_forward_extremities_counter = Histogram( + "synapse_storage_msc4242_state_dag_forward_extremities_persisted", + "Number of forward extremities for each new event in the state DAG", + labelnames=[SERVER_NAME_LABEL], + buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), +) + state_resolutions_during_persistence = Counter( "synapse_storage_events_state_resolutions_during_persistence", "Number of times we had to do state res to calculate new current state", @@ -529,7 +541,15 @@ class EventsPersistenceStorageController: Returns: map from (type, state_key) to event id for the current state in the room """ - latest_event_ids = await self.main_store.get_latest_event_ids_in_room(room_id) + room_version = await self.main_store.get_room_version_id(room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + if room_version_obj.msc4242_state_dags: + latest_event_ids = await self.main_store.get_state_dag_extremities(room_id) + else: + latest_event_ids = await self.main_store.get_latest_event_ids_in_room( + room_id + ) + state_groups = set( ( await self.main_store._get_state_group_for_events(latest_event_ids) @@ -551,7 +571,6 @@ class EventsPersistenceStorageController: # Avoid a circular import. from synapse.state import StateResolutionStore - room_version = await self.main_store.get_room_version_id(room_id) res = await self._state_resolution_handler.resolve_state_groups( room_id, room_version, @@ -615,28 +634,52 @@ class EventsPersistenceStorageController: for x in range(0, len(events_and_contexts), 100) ] + # Get the room version for the first event. This room version is the same for all events + # as events_and_contexts is all for one room. + assert len(events_and_contexts) > 0 + room_version = events_and_contexts[0][0].room_version + for chunk in chunks: # We can't easily parallelize these since different chunks # might contain the same event. 
:( new_forward_extremities = None state_delta_for_room = None + new_state_dag_extrems = None if not backfilled: - with Measure( - self._clock, - name="_calculate_state_and_extrem", - server_name=self.server_name, - ): - # Work out the new "current state" for the room. - # We do this by working out what the new extremities are and then - # calculating the state from that. - ( - new_forward_extremities, - state_delta_for_room, - ) = await self._calculate_new_forward_extremities_and_state_delta( - room_id, chunk - ) + if room_version.msc4242_state_dags: + with Measure( + self._clock, + name="_process_state_dag_forward_extremities_and_state_delta", + server_name=self.server_name, + ): + assert all( + isinstance(ev, FrozenEventVMSC4242) for ev, _ in chunk + ) + ( + new_forward_extremities, # for prev_events + state_delta_for_room, # for state groups + new_state_dag_extrems, # for prev_state_events + ) = await self._process_state_dag_forward_extremities_and_state_delta( + room_id, + cast(list[tuple[FrozenEventVMSC4242, EventContext]], chunk), + ) + else: + with Measure( + self._clock, + name="_calculate_state_and_extrem", + server_name=self.server_name, + ): + # Work out the new "current state" for the room. + # We do this by working out what the new extremities are and then + # calculating the state from that. 
+ ( + new_forward_extremities, + state_delta_for_room, + ) = await self._calculate_new_forward_extremities_and_state_delta( + room_id, chunk + ) with Measure( self._clock, @@ -666,6 +709,7 @@ class EventsPersistenceStorageController: use_negative_stream_ordering=backfilled, inhibit_local_membership_updates=backfilled, new_event_links=new_event_links, + new_state_dag_forward_extremities=new_state_dag_extrems, ) return replaced_events @@ -793,6 +837,216 @@ class EventsPersistenceStorageController: return (new_forward_extremities, delta) + async def _process_state_dag_forward_extremities_and_state_delta( + self, + room_id: str, + event_contexts: list[tuple[FrozenEventVMSC4242, EventContext]], + ) -> tuple[set[str] | None, DeltaState | None, set[str] | None]: + """Process the forwards extremities for state DAG rooms. + Returns: + - the new room dag extremities which should be written when these events are persisted. + - the state delta for the room, if applicable. + - the new state dag extremities which should be written when these events are persisted. + + NB: this does not write them because if it did, new events may see them _before_ the events + get persisted, causing failures in retrieving state groups. + """ + # Update forward extremities + # ...for the state DAG + existing_state_dag_fwd_extrems = ( + await self.main_store.get_state_dag_extremities(room_id) + ) + new_state_dag_fwd_extrems = await self._calculate_new_state_dag_extremities( + room_id, + existing_state_dag_fwd_extrems, + event_contexts, + ) + # ...and the room DAG + existing_room_dag_fwd_extrems = ( + await self.main_store.get_latest_event_ids_in_room(room_id) + ) + new_room_dag_fwd_extrems = await self._calculate_new_extremities( + room_id, + cast(list[EventPersistencePair], event_contexts), + existing_room_dag_fwd_extrems, + ) + assert new_room_dag_fwd_extrems, ( + f"No room dag forward extremities left in room {room_id}!" 
+ ) + + # See if we need to calculate a state delta + if new_state_dag_fwd_extrems == existing_state_dag_fwd_extrems: + # No change in state extremities, so no new state to calculate + return new_room_dag_fwd_extrems, None, new_state_dag_fwd_extrems + + with Measure( + self._clock, + name="persist_events.state_dag.get_new_state_after_events", + server_name=self.server_name, + ): + (current_state, delta_ids, _) = await self._get_new_state_after_events( + room_id, + cast(list[EventPersistencePair], event_contexts), + existing_state_dag_fwd_extrems, + new_state_dag_fwd_extrems, + # do not prune forward extremities in the state DAG + # else we lose eventual delivery + should_prune=False, + ) + + # Following logic cargoculted from _calculate_new_forward_extremities_and_state_delta + # If either are not None then there has been a change, + # and we need to work out the delta (or use that + # given) + delta = None + if delta_ids is not None: + # If there is a delta we know that we've + # only added or replaced state, never + # removed keys entirely. + delta = DeltaState([], delta_ids) + elif current_state is not None: + with Measure( + self._clock, + name="persist_events.calculate_state_delta", + server_name=self.server_name, + ): + delta = await self._calculate_state_delta(room_id, current_state) + + if delta: + # If we have a change of state then lets check + # whether we're actually still a member of the room, + # or if our last user left. If we're no longer in + # the room then we delete the current state and + # extremities. 
+ is_still_joined = await self._is_server_still_joined( + room_id, + cast(list[EventPersistencePair], event_contexts), + delta, + ) + if not is_still_joined: + logger.info("Server no longer in room %s", room_id) + delta.no_longer_in_room = True + + return new_room_dag_fwd_extrems, delta, new_state_dag_fwd_extrems + + async def _calculate_new_state_dag_extremities( + self, + room_id: str, + existing_fwd_extrems: frozenset[str], + event_contexts: list[tuple[FrozenEventVMSC4242, EventContext]], + ) -> set[str]: + """Calculate the new state dag forward extremities. Modifies existing_fwd_extrems. + + Assumes that event_contexts are only state events which should be in the state DAG. + + Raises: + SynapseError: if the new events include unknown prev_state_events + AssertionError: if there are no state DAG forward extremities remaining in the room + """ + # Events are always processed in causal order without any gaps in the DAG + # (prev_state_events are always known), guaranteeing that processed events have a path to the + # create event. This is an emergent property of state DAGs as asserting that there is a path + # to the create event every time we insert an event would be prohibitively expensive. + # This is similar to how doubly-linked lists can potentially not refer to previous items correctly + # without verifying the list's integrity, but doing it on every insert is too expensive. + + # filter out events which don't belong in the state dag. + new_state_events_contexts = [ + (e, ctx) for e, ctx in event_contexts if event_exists_in_state_dag(e) + ] + if len(new_state_events_contexts) == 0: + # if there are no state events being persisted, then the fwd extremities of the state dag + # do not change. + return set(existing_fwd_extrems) + + # This logic is very similar to _calculate_new_extremities with a few key differences: + # - We do not "Remove any events which are prev_events of any existing events." 
because the + # state DAG mandates that events are processed in causal order, so there MUST NOT be any + # existing, processed events which have the to-be-persisted events as prev_state_events. + # - We don't care if they are an "outlier" in the main room dag, so long as they AREN'T + # an outlier on the state dag, which this function checks, so we don't check outlier-ness. + # - We allow *soft-failed* events to become forward extremities, as per the MSC. We do not + # allow *rejected* events to become forward extremities though. + + rejected_events = [ev for ev, ctx in new_state_events_contexts if ctx.rejected] + new_state_events = [ + ev for ev, ctx in new_state_events_contexts if not ctx.rejected + ] + # We want to check that we are not missing any prev_state_events. + # To do this, we include rejected events in this check because other events may point to them. + # If we didn't include them, we might incorrectly say we are missing events when we are not. + all_new_state_events = set(rejected_events + new_state_events) + + # First, verify that we know all prev_state_events. If we fail this check then we don't have + # a complete DAG and that is bad, so bail out. + + # Start with them all missing. + missing_prev_state_events = { + e_id for event in all_new_state_events for e_id in event.prev_state_events + } + + # remove prev events which appear in all_events + missing_prev_state_events.difference_update( + event.event_id for event in all_new_state_events + ) + # the rest of these events should be present in the DB. Some of them may be forward extremities, + # some may not be, that's ok. 
+ seen_events = await self.main_store.have_seen_events( + room_id, + missing_prev_state_events, + ) + missing_prev_state_events.difference_update(seen_events) + + if len(missing_prev_state_events) > 0: + logger.error( + "_calculate_new_state_dag_extremities: missing the following prev_state_events in room %s : %s", + room_id, + missing_prev_state_events, + ) + logger.error( + "_calculate_new_state_dag_extremities: was handling %s", + shortstr([ev.event_id for ev in all_new_state_events]), + ) + raise SynapseError( + code=500, + msg=f"missing {len(missing_prev_state_events)} prev_state_events in room {room_id}", + ) + + # Now calculate the forward extremities. + + # start with the existing forward extremities + result = set(existing_fwd_extrems) + + # add all the new events to the list + result.update(event.event_id for event in new_state_events) + + # Now remove all events which are prev_state_events of any of the new events + result.difference_update( + e_id for event in new_state_events for e_id in event.prev_state_events + ) + + # Finally handle the case where the new events have rejected/soft-failed `prev_state_events`. + # If they do we need to remove them and their `prev_state_events`, + # otherwise we end up with dangling extremities. + # Specifically, this handles the case where (F=fwd extrem, SF=soft-failed, N=new event) + # F <-- SF <-- SF <-- N + # where we want to remove F as a forward extremity and replace with N. + existing_prevs = await self.persist_events_store._get_prevs_before_rejected( + (e_id for event in new_state_events for e_id in event.prev_state_events), + include_soft_failed=False, + ) + result.difference_update(existing_prevs) + + # We only update metrics for events that change forward extremities + if result != existing_fwd_extrems: + msc4242_state_dag_forward_extremities_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).observe(len(result)) + + # There should always be at least one forward extremity. 
+ assert result, f"No state dag forward extremities left in room {room_id}!" + return result + async def _calculate_new_extremities( self, room_id: str, @@ -859,6 +1113,7 @@ class EventsPersistenceStorageController: events_context: list[EventPersistencePair], old_latest_event_ids: AbstractSet[str], new_latest_event_ids: set[str], + should_prune: bool = True, ) -> tuple[StateMap[str] | None, StateMap[str] | None, set[str]]: """Calculate the current state dict after adding some new events to a room @@ -873,9 +1128,15 @@ class EventsPersistenceStorageController: old_latest_event_ids: the old forward extremities for the room. - new_latest_event_ids : + new_latest_event_ids: the new forward extremities for the room. + should_prune: + if true, attempt to prune the forward extremities. + Pruning means we will not communicate some new events to other servers, + which can compromise eventual delivery, so graphs which are fully synchronised + e.g. state DAGs should not prune. + Returns: Returns a tuple of two state maps and a set of new forward extremities. @@ -1015,7 +1276,7 @@ class EventsPersistenceStorageController: # If the returned state matches the state group of one of the new # forward extremities then we check if we are able to prune some state # extremities. - if res.state_group and res.state_group in new_state_groups: + if should_prune and res.state_group and res.state_group in new_state_groups: new_latest_event_ids = await self._prune_extremities( room_id, new_latest_event_ids, diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index e9ecf46411..8670d68f38 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -79,6 +79,19 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" +# Background update name for adding an index on +# `device_lists_changes_in_room.inserted_ts`. 
+BG_UPDATE_ADD_INSERTED_TS_INDEX = "device_lists_changes_in_room_inserted_ts_idx" + + +# Prunes entries out of the `device_lists_changes_in_room` table that are more +# than this old. +PRUNE_DEVICE_LISTS_CHANGES_IN_ROOM_AGE = Duration(days=30) + +# The number of rows to delete at once when pruning old entries out of the +# `device_lists_changes_in_room` table. +PRUNE_DEVICE_LISTS_BATCH_SIZE = 1000 + class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): _device_list_id_gen: MultiWriterIdGenerator @@ -194,6 +207,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): self.clock.looping_call( self._prune_old_outbound_device_pokes, Duration(hours=1) ) + self.clock.looping_call( + self._prune_device_lists_changes_in_room, + Duration(hours=1), + ) def process_replication_rows( self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] @@ -1143,6 +1160,35 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): The set of user_ids whose devices have changed since `from_key` (exclusive) until `to_key` (inclusive). """ + return { + user_id + for user_id, _ in await self.get_device_changes_for_users( + from_key, user_ids, to_key + ) + } + + @cancellable + async def get_device_changes_for_users( + self, + from_key: MultiWriterStreamToken, + user_ids: Collection[str], + to_key: MultiWriterStreamToken | None = None, + ) -> set[tuple[str, str]]: + """Get set of user/device ID tuple whose devices have changed since `from_key` that + are in the given list of user_ids. + + Args: + from_key: The minimum device lists stream token to query device list changes for, + exclusive. + user_ids: If provided, only check if these users have changed their device lists. + Otherwise changes from all users are returned. + to_key: The maximum device lists stream token to query device list changes for, + inclusive. If None then no upper limit is applied. 
+ + Returns: + The set of user/device ID tuples whose devices have changed since `from_key` + (exclusive) until `to_key` (inclusive). + """ # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. user_ids_to_check = self._device_list_stream_cache.get_entities_changed( @@ -1156,18 +1202,18 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): if to_key is None: to_key = self.get_device_stream_token() - def _get_users_whose_devices_changed_txn( + def get_device_changes_for_users_txn( txn: LoggingTransaction, from_key: MultiWriterStreamToken, to_key: MultiWriterStreamToken, - ) -> set[str]: + ) -> set[tuple[str, str]]: sql = """ - SELECT user_id, stream_id, instance_name + SELECT user_id, device_id, stream_id, instance_name FROM device_lists_stream WHERE ? < stream_id AND stream_id <= ? AND %s """ - changes: set[str] = set() + changes: set[tuple[str, str]] = set() # Query device changes with a batch of users at a time for chunk in batch_iter(user_ids_to_check, 100): @@ -1179,8 +1225,8 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): [from_key.stream, to_key.get_max_stream_pos()] + args, ) changes.update( - user_id - for (user_id, stream_id, instance_name) in txn + (user_id, device_id) + for (user_id, device_id, stream_id, instance_name) in txn if MultiWriterStreamToken.is_stream_position_in_range( low=from_key, high=to_key, @@ -1192,8 +1238,8 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return changes return await self.db_pool.runInteraction( - "get_users_whose_devices_changed", - _get_users_whose_devices_changed_txn, + "get_device_changes_for_users", + get_device_changes_for_users_txn, from_key, to_key, ) @@ -1699,17 +1745,22 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return devices - @cached() - async def _get_min_device_lists_changes_in_room(self) -> int: - """Returns the minimum stream ID that we have entries 
for - `device_lists_changes_in_room` + def _get_max_pruned_device_lists_changes_in_room_txn( + self, txn: LoggingTransaction + ) -> int: + """Returns the maximum stream ID that has been pruned from + `device_lists_changes_in_room`. + + Any queries for stream IDs less than this value cannot be answered + completely, as the data has been deleted. """ - return await self.db_pool.simple_select_one_onecol( - table="device_lists_changes_in_room", + return self.db_pool.simple_select_one_onecol_txn( + txn, + table="device_lists_changes_in_room_max_pruned_stream_id", keyvalues={}, - retcol="COALESCE(MIN(stream_id), 0)", - desc="get_min_device_lists_changes_in_room", + retcol="stream_id", + allow_none=False, ) @cancellable @@ -1728,55 +1779,54 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): if not room_ids: return set() - min_stream_id = await self._get_min_device_lists_changes_in_room() - - # Return early if there are no rows to process in device_lists_changes_in_room - if min_stream_id > from_token.stream: - return None - changed_room_ids = self._device_list_room_stream_cache.get_entities_changed( room_ids, from_token.stream ) if not changed_room_ids: return set() - sql = """ - SELECT user_id, stream_id, instance_name - FROM device_lists_changes_in_room - WHERE {clause} AND stream_id > ? AND stream_id <= ? - """ - def _get_device_list_changes_in_rooms_txn( txn: LoggingTransaction, - chunk: list[str], - ) -> set[str]: - clause, args = make_in_list_sql_clause( - self.database_engine, "room_id", chunk + ) -> set[str] | None: + # Check if the from_token is too old (i.e. data has been pruned). 
+ max_pruned_stream_id = ( + self._get_max_pruned_device_lists_changes_in_room_txn(txn) ) - args.append(from_token.stream) - args.append(to_token.get_max_stream_pos()) + if max_pruned_stream_id > from_token.stream: + return None - txn.execute(sql.format(clause=clause), args) - return { - user_id - for (user_id, stream_id, instance_name) in txn - if MultiWriterStreamToken.is_stream_position_in_range( - low=from_token, - high=to_token, - instance_name=instance_name, - pos=stream_id, + changes: set[str] = set() + + for chunk in batch_iter(changed_room_ids, 1000): + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", chunk ) - } + args.append(from_token.stream) + args.append(to_token.get_max_stream_pos()) - changes = set() - for chunk in batch_iter(changed_room_ids, 1000): - changes |= await self.db_pool.runInteraction( - "get_device_list_changes_in_rooms", - _get_device_list_changes_in_rooms_txn, - chunk, - ) + sql = f""" + SELECT user_id, stream_id, instance_name + FROM device_lists_changes_in_room + WHERE {clause} AND stream_id > ? AND stream_id <= ? + """ + txn.execute(sql, args) + changes.update( + user_id + for (user_id, stream_id, instance_name) in txn + if MultiWriterStreamToken.is_stream_position_in_range( + low=from_token, + high=to_token, + instance_name=instance_name, + pos=stream_id, + ) + ) - return changes + return changes + + return await self.db_pool.runInteraction( + "get_device_list_changes_in_rooms", + _get_device_list_changes_in_rooms_txn, + ) async def get_all_device_list_changes(self, from_id: int, to_id: int) -> set[str]: """Return the set of rooms where devices have changed since the given @@ -1785,46 +1835,66 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): Will raise an exception if the given stream ID is too old. 
""" - min_stream_id = await self._get_min_device_lists_changes_in_room() - - if min_stream_id > from_id: - raise Exception("stream ID is too old") - - sql = """ - SELECT DISTINCT room_id FROM device_lists_changes_in_room - WHERE stream_id > ? AND stream_id <= ? - """ - def _get_all_device_list_changes_txn( txn: LoggingTransaction, - ) -> set[str]: + ) -> set[str] | None: + # Check if the from_token is too old (i.e. data has been pruned). + max_pruned_stream_id = ( + self._get_max_pruned_device_lists_changes_in_room_txn(txn) + ) + if max_pruned_stream_id > from_id: + logger.warning( + "Given stream ID is too old %d < %d", + from_id, + max_pruned_stream_id, + ) + return None + + sql = """ + SELECT DISTINCT room_id FROM device_lists_changes_in_room + WHERE stream_id > ? AND stream_id <= ? + """ + txn.execute(sql, (from_id, to_id)) return {room_id for (room_id,) in txn} - return await self.db_pool.runInteraction( + room_ids = await self.db_pool.runInteraction( "get_all_device_list_changes", _get_all_device_list_changes_txn, ) + if room_ids is None: + raise Exception(f"Given stream ID is too old {from_id}") + + return room_ids + async def get_device_list_changes_in_room( self, room_id: str, min_stream_id: int - ) -> Collection[tuple[str, str]]: + ) -> Collection[tuple[str, str]] | None: """Get all device list changes that happened in the room since the given stream ID. Returns: Collection of user ID/device ID tuples of all devices that have - changed - """ - - sql = """ - SELECT DISTINCT user_id, device_id FROM device_lists_changes_in_room - WHERE room_id = ? AND stream_id > ? + changed, or None if the given stream ID is too old and so a complete + list cannot be calculated. """ def get_device_list_changes_in_room_txn( txn: LoggingTransaction, - ) -> Collection[tuple[str, str]]: + ) -> Collection[tuple[str, str]] | None: + # Check if the from_token is too old (i.e. data has been pruned). 
+ max_pruned_stream_id = ( + self._get_max_pruned_device_lists_changes_in_room_txn(txn) + ) + if max_pruned_stream_id > min_stream_id: + return None + + sql = """ + SELECT DISTINCT user_id, device_id FROM device_lists_changes_in_room + WHERE room_id = ? AND stream_id > ? + """ + txn.execute(sql, (room_id, min_stream_id)) return cast(Collection[tuple[str, str]], txn.fetchall()) @@ -2160,6 +2230,8 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): encoded_context = json_encoder.encode(context) + now = self.clock.time_msec() + # The `device_lists_changes_in_room.stream_id` column matches the # corresponding `stream_id` of the update in the `device_lists_stream` # table, i.e. all rows persisted for the same device update will have @@ -2175,6 +2247,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): "instance_name", "converted_to_destinations", "opentracing_context", + "inserted_ts", ), values=[ ( @@ -2186,6 +2259,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): # We only need to calculate outbound pokes for local users not self.hs.is_mine_id(user_id), encoded_context, + now, ) for room_id in room_ids for device_id, stream_id in zip(device_ids, stream_ids) @@ -2401,6 +2475,154 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): desc="set_device_change_last_converted_pos", ) + @wrap_as_background_process("prune_device_lists_changes_in_room") + async def _prune_device_lists_changes_in_room(self) -> None: + """Delete old entries out of the `device_lists_changes_in_room`, so that + the table doesn't grow indefinitely. + """ + + # Let's only do this pruning if the index on inserted_ts has been + # created, otherwise this query will be very inefficient. 
+ has_index_been_created = ( + await self.db_pool.updates.has_completed_background_update( + BG_UPDATE_ADD_INSERTED_TS_INDEX + ) + ) + if not has_index_been_created: + return + + prune_before_ts = ( + self.clock.time_msec() - PRUNE_DEVICE_LISTS_CHANGES_IN_ROOM_AGE.as_millis() + ) + + # Get stream ID corresponding to the prune_before_ts timestamp. We can + # delete all rows with a stream ID less than or equal to this, as they + # will be older than the cutoff. + # + # Some rows will have a NULL inserted_ts (due to being inserted before + # the column was added), but we can assume that the timestamp will + # monotonically increase with stream ID, so we can safely ignore those + # rows when calculating the cutoff stream ID. This means that we may end + # up keeping some rows with a non-NULL inserted_ts that are older than + # the cutoff, but that's better than accidentally deleting rows that are + # newer than the cutoff. + cutoff_sql = """ + SELECT stream_id FROM device_lists_changes_in_room + WHERE inserted_ts <= ? AND inserted_ts IS NOT NULL + ORDER BY inserted_ts DESC + LIMIT 1 + """ + + def get_prune_before_stream_id_txn(txn: LoggingTransaction) -> int | None: + txn.execute(cutoff_sql, (prune_before_ts,)) + row = txn.fetchone() + return row[0] if row else None + + prune_before_stream_id = await self.db_pool.runInteraction( + "prune_device_lists_changes_in_room_get_stream_id", + get_prune_before_stream_id_txn, + ) + + if prune_before_stream_id is None: + return + + # Get the max stream ID in the table so we avoid deleting it. We need + # to keep the latest row so that we can calculate the maximum stream ID + # used. 
+ max_stream_id = await self.db_pool.simple_select_one_onecol( + table="device_lists_changes_in_room", + keyvalues={}, + retcol="MAX(stream_id)", + desc="prune_device_lists_changes_in_room_get_max_stream_id", + ) + if prune_before_stream_id >= max_stream_id: + prune_before_stream_id = max_stream_id - 1 + + logger.debug( + "Pruning device_lists_changes_in_room before stream ID %d (timestamp %d)", + prune_before_stream_id, + prune_before_ts, + ) + + # Now delete all rows with stream_id less than the + # prune_before_stream_id. + # + # We also delete in batches to avoid massive churn when initially + # clearing out all the old entries. + # + # We set a minimum stream ID so that when we delete in batches the + # database doesn't have to scan through all the (dead) tuples that were just + # deleted to find the next batch to delete. + + # The minimum stream ID to delete in the next batch, c.f. comment above. + # We default to 0 here as that is less than all possible stream IDs. + min_stream_id = 0 + + def prune_device_lists_changes_in_room_txn(txn: LoggingTransaction) -> int: + nonlocal min_stream_id + + delete_sql = """ + DELETE FROM device_lists_changes_in_room + WHERE stream_id IN ( + SELECT stream_id FROM device_lists_changes_in_room + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + ) + RETURNING stream_id + """ + txn.execute( + delete_sql, + (min_stream_id, prune_before_stream_id, PRUNE_DEVICE_LISTS_BATCH_SIZE), + ) + + # We can't use rowcount as that is incorrect on SQLite when using + # RETURNING. + num_deleted = 0 + for row in txn: + num_deleted += 1 + min_stream_id = max(min_stream_id, row[0]) + + if num_deleted: + # Update the max pruned stream ID tracking table so that the + # safety check knows data up to this point has been deleted. 
+ self.db_pool.simple_update_one_txn( + txn, + table="device_lists_changes_in_room_max_pruned_stream_id", + keyvalues={}, + updatevalues={"stream_id": min_stream_id}, + ) + + return num_deleted + + progress_num_rows_deleted = 0 + while True: + batch_deleted = await self.db_pool.runInteraction( + "prune_device_lists_changes_in_room", + prune_device_lists_changes_in_room_txn, + ) + + finished = batch_deleted < PRUNE_DEVICE_LISTS_BATCH_SIZE + + progress_num_rows_deleted += batch_deleted + + # Periodically report progress in the logs. We do this either when + # we've deleted a significant number of rows or when we've finished + # deleting all rows in this round. + if finished or progress_num_rows_deleted > 10000: + logger.info( + "Pruned %d rows from device_lists_changes_in_room", + progress_num_rows_deleted, + ) + progress_num_rows_deleted = 0 + + if finished: + break + + # Sleep for a short time to avoid hammering the database too much if + # there are a lot of rows to delete. + await self.clock.sleep(Duration(milliseconds=100)) + class DeviceBackgroundUpdateStore(SQLBaseStore): _instance_name: str @@ -2459,6 +2681,15 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): columns=["room_id", "stream_id"], ) + # Add indexes to speed up pruning of device_lists_changes_in_room + self.db_pool.updates.register_background_index_update( + BG_UPDATE_ADD_INSERTED_TS_INDEX, + index_name="device_lists_changes_in_room_inserted_ts_idx", + table="device_lists_changes_in_room", + columns=["inserted_ts"], + where_clause="inserted_ts IS NOT NULL", + ) + async def _drop_device_list_streams_non_unique_indexes( self, progress: JsonDict, batch_size: int ) -> int: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index cc7083b605..415926eb0a 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1493,6 +1493,15 @@ class EventFederationWorkerStore( ) return 
frozenset(event_ids) + async def get_state_dag_extremities(self, room_id: str) -> frozenset[str]: + event_ids = await self.db_pool.simple_select_onecol( + table="msc4242_state_dag_forward_extremities", + keyvalues={"room_id": room_id}, + retcol="event_id", + desc="get_state_dag_extremities", + ) + return frozenset(event_ids) + async def get_min_depth(self, room_id: str) -> int | None: """For the given room, get the minimum depth we have seen for it.""" return await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 6d3bc15777..12c918eca6 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -48,7 +48,9 @@ from synapse.api.errors import PartialStateConflictError from synapse.api.room_versions import RoomVersions from synapse.events import ( EventBase, + FrozenEventVMSC4242, StrippedStateEvent, + event_exists_in_state_dag, is_creator, relation_from_event, ) @@ -295,6 +297,7 @@ class PersistEventsStore: new_event_links: dict[str, NewEventChainLinks], use_negative_stream_ordering: bool = False, inhibit_local_membership_updates: bool = False, + new_state_dag_forward_extremities: set[str] | None = None, ) -> None: """Persist a set of events alongside updates to the current state and forward extremities tables. @@ -315,6 +318,8 @@ class PersistEventsStore: from being updated by these events. This should be set to True for backfilled events because backfilled events in the past do not affect the current local state. + new_state_dag_forward_extremities: A set of event IDs that are the new forward + extremities for the state DAG for this room. MSC4242 only. 
Returns: Resolves when the events have been persisted @@ -379,6 +384,7 @@ class PersistEventsStore: new_forward_extremities=new_forward_extremities, new_event_links=new_event_links, sliding_sync_table_changes=sliding_sync_table_changes, + new_state_dag_forward_extremities=new_state_dag_forward_extremities, ) persist_event_counter.labels(**{SERVER_NAME_LABEL: self.server_name}).inc( len(events_and_contexts) @@ -962,8 +968,10 @@ class PersistEventsStore: return results - async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> set[str]: - """Get soft-failed ancestors to remove from the extremities. + async def _get_prevs_before_rejected( + self, event_ids: Iterable[str], include_soft_failed: bool = True + ) -> set[str]: + """Get soft-failed/rejected ancestors to remove from the extremities. Given a set of events, find all those that have been soft-failed or rejected. Returns those soft failed/rejected events and their prev @@ -976,7 +984,8 @@ class PersistEventsStore: Args: event_ids: Events to find prev events for. Note that these must have already been persisted. - + include_soft_failed: Soft-failed events are included in the search. If false, only + rejected events are included. Returns: The previous events. """ @@ -1016,7 +1025,7 @@ class PersistEventsStore: continue soft_failed = db_to_json(metadata).get("soft_failed") - if soft_failed or rejected: + if (include_soft_failed and soft_failed) or rejected: to_recursively_check.append(prev_event_id) existing_prevs.add(prev_event_id) @@ -1038,6 +1047,7 @@ class PersistEventsStore: new_forward_extremities: set[str] | None, new_event_links: dict[str, NewEventChainLinks], sliding_sync_table_changes: SlidingSyncTableChanges | None, + new_state_dag_forward_extremities: set[str] | None = None, ) -> None: """Insert some number of room events into the necessary database tables. 
@@ -1146,6 +1156,11 @@ class PersistEventsStore: max_stream_order=max_stream_order, ) + if new_state_dag_forward_extremities: + self._set_state_dag_extremities_txn( + txn, room_id, new_state_dag_forward_extremities + ) + self._persist_transaction_ids_txn(txn, events_and_contexts) # Insert into event_to_state_groups. @@ -2475,6 +2490,29 @@ class PersistEventsStore: ], ) + def _set_state_dag_extremities_txn( + self, txn: LoggingTransaction, room_id: str, new_extrems: Collection[str] + ) -> None: + self.db_pool.simple_delete_txn( + txn, + table="msc4242_state_dag_forward_extremities", + keyvalues={ + "room_id": room_id, + }, + ) + self.db_pool.simple_insert_many_txn( + txn, + table="msc4242_state_dag_forward_extremities", + keys=("room_id", "event_id"), + values=[ + ( + room_id, + event_id, + ) + for event_id in new_extrems + ], + ) + @classmethod def _filter_events_and_contexts_for_duplicates( cls, events_and_contexts: list[EventPersistencePair] @@ -2859,6 +2897,12 @@ class PersistEventsStore: self._handle_event_relations(txn, event) + if event.room_version.msc4242_state_dags and event_exists_in_state_dag( + event + ): + assert isinstance(event, FrozenEventVMSC4242) + self._store_state_dag_edges(txn, event) + # Store the labels for this event. labels = event.content.get(EventContentFields.LABELS) if labels: @@ -2935,6 +2979,36 @@ class PersistEventsStore: txn.async_call_after(external_prefill) txn.call_after(local_prefill) + def _store_state_dag_edges( + self, txn: LoggingTransaction, event: FrozenEventVMSC4242 + ) -> None: + # the create event has no edge but we still need to persist it as get_state_dag just + # yanks all rows in this table. It's a bit gross to store NULL as the prev_state_event_id + # though. 
+ if len(event.prev_state_events) == 0 and event.type == EventTypes.Create: + self.db_pool.simple_insert_txn( + txn, + table="msc4242_state_dag_edges", + values={ + "room_id": event.room_id, + "event_id": event.event_id, + "prev_state_event_id": None, + }, + ) + return + assert len(event.prev_state_events) > 0 + self.db_pool.simple_upsert_many_txn( + txn, + table="msc4242_state_dag_edges", + key_names=["room_id", "event_id", "prev_state_event_id"], + key_values=[ + (event.room_id, event.event_id, prev_state_event) + for prev_state_event in event.prev_state_events + ], + value_names=(), + value_values=(), + ) + def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None: assert event.redacts is not None self.db_pool.simple_upsert_txn( @@ -3456,7 +3530,13 @@ class PersistEventsStore: """ state_groups = {} for event, context in events_and_contexts: - if event.internal_metadata.is_outlier(): + # state dag rooms allow outliers to have state, as `/get_missing_events` state dag events are nominally + # outliers (not present in the timeline) but do need state persisted so we can calculate + # what the auth_events are for the event. + if ( + not event.room_version.msc4242_state_dags + and event.internal_metadata.is_outlier() + ): # double-check that we don't have any events that claim to be outliers # *and* have partial state (which is meaningless: we should have no # state at all for an outlier) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index d55ea5cf7d..fe8079c201 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -71,6 +71,10 @@ purge_room_tables_with_room_id_column = ( # so must be deleted first. "sliding_sync_joined_rooms", "sliding_sync_membership_snapshots", + # Note: msc4242_state_dag_forward_extremities/edges have a foreign key to the `events` table + # so must be deleted first. 
+ "msc4242_state_dag_forward_extremities", + "msc4242_state_dag_edges", "events", "federation_inbound_events_staging", "receipts_graph", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 8dbecc16e4..a0c42082f0 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -2277,6 +2277,132 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): return True + async def get_user_report( + self, report_id: int + ) -> tuple[int, int, str, str, str] | None: + """Retrieve a user report + + Args: + report_id: ID of user report in database + Returns: + JSON dict of information from a user report or None if the + report does not exist. + """ + + return await self.db_pool.simple_select_one( + table="user_reports", + keyvalues={"id": report_id}, + retcols=("id", "received_ts", "target_user_id", "user_id", "reason"), + allow_none=True, + desc="get_user_report", + ) + + async def get_user_reports_paginate( + self, + start: int, + limit: int, + direction: Direction = Direction.BACKWARDS, + user_id: str | None = None, + target_user_id: str | None = None, + ) -> tuple[list[JsonDict], int]: + """Retrieve a paginated list of user reports + + Args: + start: event offset to begin the query from + limit: number of rows to retrieve + direction: Whether to fetch the most recent first (backwards) or the + oldest first (forwards) + user_id: search for user_id of the reporter. Ignored if user_id is None + target_user_id: search for user_id of the target. 
Ignored if target_user_id is None + Returns: + Tuple of: + json list of user reports + total number of user reports matching the filter criteria + """ + + def _get_user_reports_paginate_txn( + txn: LoggingTransaction, + ) -> tuple[list[dict[str, Any]], int]: + filters = [] + args: list[object] = [] + + if user_id: + filters.append("user_id LIKE ?") + args.extend(["%" + user_id + "%"]) + if target_user_id: + filters.append("target_user_id LIKE ?") + args.extend(["%" + target_user_id + "%"]) + + if direction == Direction.BACKWARDS: + order = "DESC" + else: + order = "ASC" + + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" + + sql = f""" + SELECT COUNT(*) as total_user_reports + FROM user_reports {where_clause} + """ + txn.execute(sql, args) + count = cast(tuple[int], txn.fetchone())[0] + + sql = f""" + SELECT + id, + received_ts, + target_user_id, + user_id, + reason + FROM user_reports + {where_clause} + ORDER BY received_ts {order} + LIMIT ? + OFFSET ? + """ + + args += [limit, start] + txn.execute(sql, args) + + user_reports = [] + for row in txn: + user_reports.append( + { + "id": row[0], + "received_ts": row[1], + "target_user_id": row[2], + "user_id": row[3], + "reason": row[4], + } + ) + + return user_reports, count + + return await self.db_pool.runInteraction( + "get_user_reports_paginate", _get_user_reports_paginate_txn + ) + + async def delete_user_report(self, report_id: int) -> bool: + """Remove a user report from database. + + Args: + report_id: Report to delete + + Returns: + Whether the report was successfully deleted or not. 
+ """ + try: + await self.db_pool.simple_delete_one( + table="user_reports", + keyvalues={"id": report_id}, + desc="delete_user_report", + ) + except StoreError: + # Deletion failed because report does not exist + return False + + return True + async def set_room_is_public(self, room_id: str, is_public: bool) -> None: await self.db_pool.simple_update_one( table="rooms", diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index cfde107b48..87523e6f18 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -38,7 +38,7 @@ import attr from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion -from synapse.events import EventBase +from synapse.events import EventBase, EventMetadata from synapse.events.snapshot import EventContext from synapse.logging.opentracing import trace from synapse.replication.tcp.streams import UnPartialStatedEventStream @@ -78,16 +78,6 @@ class Sentinel: ROOM_UNKNOWN_SENTINEL = Sentinel() -@attr.s(slots=True, frozen=True, auto_attribs=True) -class EventMetadata: - """Returned by `get_metadata_for_events`""" - - room_id: str - event_type: str - state_key: str | None - rejection_reason: str | None - - def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion: v = KNOWN_ROOM_VERSIONS.get(room_version_id) if not v: diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index e3095a9d0d..1afc6d0b2a 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -174,6 +174,7 @@ Changes in SCHEMA_VERSION = 93 Changes in SCHEMA_VERSION = 94 - Add `recheck` column (boolean, default true) to the `redactions` table. + - MSC4242: Add state DAG tables. 
""" diff --git a/synapse/storage/schema/main/delta/94/03_device_lists_room_timestamp.sql b/synapse/storage/schema/main/delta/94/03_device_lists_room_timestamp.sql new file mode 100644 index 0000000000..b0ae1eaaf6 --- /dev/null +++ b/synapse/storage/schema/main/delta/94/03_device_lists_room_timestamp.sql @@ -0,0 +1,18 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 Element Creations, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +ALTER TABLE device_lists_changes_in_room ADD COLUMN inserted_ts BIGINT; + +-- Add a background update to add index +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (9403, 'device_lists_changes_in_room_inserted_ts_idx', '{}'); diff --git a/synapse/storage/schema/main/delta/94/03_state_dag_fwd_extrems.sql b/synapse/storage/schema/main/delta/94/03_state_dag_fwd_extrems.sql new file mode 100644 index 0000000000..bc5c738ba5 --- /dev/null +++ b/synapse/storage/schema/main/delta/94/03_state_dag_fwd_extrems.sql @@ -0,0 +1,38 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2026 Element Creations, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +CREATE TABLE IF NOT EXISTS msc4242_state_dag_forward_extremities( + -- we always expect the room to exist. If it gets removed, delete fwd extremities. 
+ room_id TEXT NOT NULL REFERENCES rooms(room_id) ON DELETE CASCADE, + event_id TEXT NOT NULL REFERENCES events(event_id) ON DELETE CASCADE, + -- it doesn't make sense to reference the same event multiple times, and this uniqueness + -- index is also used to delete events once they are no longer forward extremities. + UNIQUE (event_id) +); +-- When creating events, we want to select all forward extremities for a room which this index helps with. +CREATE INDEX msc4242_state_dag_room ON msc4242_state_dag_forward_extremities(room_id); + + +CREATE TABLE IF NOT EXISTS msc4242_state_dag_edges( + -- Deleting the room deletes the state DAG. + room_id TEXT NOT NULL REFERENCES rooms(room_id) ON DELETE CASCADE, + -- the event IDs being referenced must exist (hence REFERENCES) and we do not want to accidentally delete + -- the event and create a hole in the state DAG. It is not possible for a state + -- DAG room to function with a holey DAG, so these events _cannot_ be purged. To purge them, the + -- entire room would need to be deleted. + event_id TEXT NOT NULL REFERENCES events(event_id), + -- one of the `prev_state_events` for this event ID. We must have it since we must have the entire state DAG. + -- can be NULL for the create event. + prev_state_event_id TEXT REFERENCES events(event_id) +); +CREATE UNIQUE INDEX msc4242_state_dag_edges_key ON msc4242_state_dag_edges(room_id, event_id, prev_state_event_id); diff --git a/synapse/storage/schema/main/delta/94/04_device_lists_changes_max_pruned.sql b/synapse/storage/schema/main/delta/94/04_device_lists_changes_max_pruned.sql new file mode 100644 index 0000000000..73836841da --- /dev/null +++ b/synapse/storage/schema/main/delta/94/04_device_lists_changes_max_pruned.sql @@ -0,0 +1,34 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3.
+-- +-- Copyright (C) 2026 Element Creations Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + + +-- Tracks the maximum stream_id that has been deleted (pruned) from the +-- device_lists_changes_in_room table. This is used to determine whether it's +-- safe to read from that table for a given stream_id — if the requested +-- stream_id is < the value here, the data has been pruned and the table cannot +-- provide a complete answer. +-- +-- We need a separate table, rather than looking at the minimum stream_id in the +-- device_lists_changes_in_room table, because not all valid stream IDs will +-- have entries in the table. This could lead to situations where the minimum +-- stream ID was potentially much more recent than when we actually pruned. This +-- would cause us to incorrectly think that the table was not safe to read from, +-- when in fact it was. +CREATE TABLE IF NOT EXISTS device_lists_changes_in_room_max_pruned_stream_id ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, + stream_id BIGINT NOT NULL +); + +-- We assume that nothing has been deleted from the device_lists_changes_in_room +-- table, so we can set the initial value to 0. +INSERT INTO device_lists_changes_in_room_max_pruned_stream_id (stream_id) VALUES (0); diff --git a/synapse/synapse_rust/events.pyi b/synapse/synapse_rust/events.pyi index 29432bdd56..fe0ca04420 100644 --- a/synapse/synapse_rust/events.pyi +++ b/synapse/synapse_rust/events.pyi @@ -60,6 +60,9 @@ class EventInternalMetadata: device_id: str """The device ID of the user who sent this event, if any.""" + # MSC4242 state dags + calculated_auth_event_ids: list[str] + def get_dict(self) -> JsonDict: ... def is_outlier(self) -> bool: ... 
def copy(self) -> "EventInternalMetadata": ... diff --git a/synapse/synapse_rust/room_versions.pyi b/synapse/synapse_rust/room_versions.pyi index 909e3a1c26..9bbb538f18 100644 --- a/synapse/synapse_rust/room_versions.pyi +++ b/synapse/synapse_rust/room_versions.pyi @@ -31,6 +31,8 @@ class EventFormatVersions: """MSC1884-style format: introduced for room v4""" ROOM_V11_HYDRA_PLUS: int """MSC4291 room IDs as hashes: introduced for room HydraV11""" + ROOM_VMSC4242: int + """MSC4242 state DAGs: adds prev_state_events, removes auth_events""" KNOWN_EVENT_FORMAT_VERSIONS: frozenset[int] @@ -113,6 +115,14 @@ class RoomVersion: rather than in codepoints. If true, this room version uses stricter event size validation.""" + msc4242_state_dags: bool + """MSC4242: State DAGs. Creates events with prev_state_events instead of auth_events and derives + state from it. Events are always processed in causal order without any gaps in the DAG + (prev_state_events are always known), guaranteeing that processed events have a path to the + create event. This is an emergent property of state DAGs as asserting that there is a path + to the create event every time we insert an event would be prohibitively expensive. + This is similar to how doubly-linked lists can potentially not refer to previous items correctly + without verifying the list's integrity, but doing it on every insert is too expensive.""" class RoomVersions: V1: RoomVersion @@ -132,6 +142,7 @@ class RoomVersions: MSC3757v11: RoomVersion HydraV11: RoomVersion V12: RoomVersion + MSC4242v12: RoomVersion class KnownRoomVersionsMapping(Mapping[str, RoomVersion]): def add_room_version(self, room_version: RoomVersion) -> None: ... 
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 275e5dfa1d..a40e0b0680 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -432,7 +432,14 @@ class UnstableGetExtremitiesTests(unittest.FederatingHomeserverTestCase): self.assertEqual(channel.json_body["error"], "Server is banned from room") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - @parameterized.expand([(k,) for k in KNOWN_ROOM_VERSIONS.keys()]) + # FIXME: Exclude MSC4242 room versions whilst it lacks federation support + @parameterized.expand( + [ + (k,) + for k in KNOWN_ROOM_VERSIONS.keys() + if k != RoomVersions.MSC4242v12.identifier + ] + ) @override_config( {"use_frozen_dicts": True, "experimental_features": {"msc4370_enabled": True}} ) @@ -440,7 +447,14 @@ class UnstableGetExtremitiesTests(unittest.FederatingHomeserverTestCase): """Test GET /extremities with USE_FROZEN_DICTS=True""" self._test_get_extremities_common(room_version) - @parameterized.expand([(k,) for k in KNOWN_ROOM_VERSIONS.keys()]) + # FIXME: Exclude MSC4242 room versions whilst it lacks federation support + @parameterized.expand( + [ + (k,) + for k in KNOWN_ROOM_VERSIONS.keys() + if k != RoomVersions.MSC4242v12.identifier + ] + ) @override_config( {"use_frozen_dicts": False, "experimental_features": {"msc4370_enabled": True}} ) @@ -573,12 +587,18 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): @override_config({"use_frozen_dicts": True}) def test_send_join_with_frozen_dicts(self, room_version: str) -> None: """Test send_join with USE_FROZEN_DICTS=True""" + if room_version == RoomVersions.MSC4242v12.identifier: + # TODO: This room version doesn't work over federation in this PR. 
+ return self._test_send_join_common(room_version) @parameterized.expand([(k,) for k in KNOWN_ROOM_VERSIONS.keys()]) @override_config({"use_frozen_dicts": False}) def test_send_join_without_frozen_dicts(self, room_version: str) -> None: """Test send_join with USE_FROZEN_DICTS=False""" + if room_version == RoomVersions.MSC4242v12.identifier: + # TODO: This room version doesn't work over federation in this PR. + return self._test_send_join_common(room_version) def test_send_join_partial_state(self) -> None: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index f99e3cd4a2..9e44b1dc1e 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -21,20 +21,42 @@ # from unittest import mock +from unittest.mock import AsyncMock, Mock, patch +import signedjson.key +from parameterized import parameterized +from signedjson.types import SigningKey + +from twisted.internet import defer from twisted.internet.defer import ensureDeferred from twisted.internet.testing import MemoryReactor -from synapse.api.constants import RoomEncryptionAlgorithms +from synapse.api.constants import EventTypes, JoinRules, RoomEncryptionAlgorithms from synapse.api.errors import NotFoundError, SynapseError +from synapse.api.room_versions import RoomVersions from synapse.appservice import ApplicationService +from synapse.crypto.event_signing import add_hashes_and_signatures +from synapse.events import EventBase, FrozenEventV3 +from synapse.federation.federation_client import SendJoinResult +from synapse.federation.transport.client import ( + StateRequestResponse, + TransportLayerClient, +) +from synapse.federation.units import Transaction from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceWriterHandler from synapse.rest import admin from synapse.rest.client import devices, login, register from synapse.server import HomeServer from synapse.storage.databases.main.appservice import _make_exclusive_regex -from synapse.types import 
JsonDict, UserID, create_requester +from synapse.types import ( + JsonDict, + StateMap, + UserID, + create_requester, + get_domain_from_id, +) from synapse.util.clock import Clock +from synapse.util.duration import Duration from synapse.util.task_scheduler import TaskScheduler from tests import unittest @@ -581,3 +603,334 @@ class DehydrationTestCase(unittest.HomeserverTestCase): self.assertTrue(len(res["next_batch"]) > 1) self.assertEqual(len(res["events"]), 1) self.assertEqual(res["events"][0]["content"]["body"], "foo") + + +@patch("synapse.crypto.keyring.Keyring.process_request", AsyncMock(return_value=None)) +class DeviceUnPartialStateTestCase(unittest.HomeserverTestCase): + """Tests that local device list changes during partial state are sent to + remote servers when the room un-partials.""" + + servlets = [ + admin.register_servlets, + login.register_servlets, + ] + + # The two remote servers to fake + REMOTE1_SERVER_NAME = "remote1" + REMOTE1_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test") + REMOTE1_USER = f"@user:{REMOTE1_SERVER_NAME}" + + REMOTE2_SERVER_NAME = "remote2" + REMOTE2_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test") + REMOTE2_USER = f"@user:{REMOTE2_SERVER_NAME}" + + def default_config(self) -> JsonDict: + config = super().default_config() + # Enable federation so that get_device_updates_by_remote works. + config["federation_sender_instances"] = ["master"] + return config + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + # Mock the federation transport client to prevent actual network calls. 
+ self.federation_transport_client = AsyncMock(TransportLayerClient) + + self.federation_transport_client.send_transaction.return_value = {} + + hs = self.setup_test_homeserver( + federation_transport_client=self.federation_transport_client, + ) + handler = hs.get_device_handler() + assert isinstance(handler, DeviceWriterHandler) + self.device_handler = handler + self.store = hs.get_datastores().main + + return hs + + def _build_public_room(self) -> StateMap[EventBase]: + """Build a public room DAG that has REMOTE1 in it""" + + room_id = f"!room:{self.REMOTE1_SERVER_NAME}" + room_version = RoomVersions.V10 + + events: list[EventBase] = [] + + # First we make the create event + create_event_dict: JsonDict = { + "auth_events": [], + "content": { + "creator": self.REMOTE1_USER, + "room_version": room_version.identifier, + }, + "depth": 0, + "origin_server_ts": 0, + "prev_events": [], + "room_id": room_id, + "sender": self.REMOTE1_USER, + "state_key": "", + "type": EventTypes.Create, + } + + add_hashes_and_signatures( + room_version, + create_event_dict, + self.REMOTE1_SERVER_NAME, + self.REMOTE1_SERVER_SIGNATURE_KEY, + ) + + create_event = FrozenEventV3(create_event_dict, room_version, {}, None) + events.append(create_event) + + room_version = self.hs.config.server.default_room_version + join_event_dict: JsonDict = { + "auth_events": [ + create_event.event_id, + ], + "content": {"membership": "join"}, + "depth": 1, + "origin_server_ts": 100, + "prev_events": [create_event.event_id], + "sender": self.REMOTE1_USER, + "state_key": self.REMOTE1_USER, + "room_id": room_id, + "type": EventTypes.Member, + } + add_hashes_and_signatures( + room_version, + join_event_dict, + self.hs.hostname, + self.hs.signing_key, + ) + join_event = FrozenEventV3(join_event_dict, room_version, {}, None) + events.append(join_event) + + # Then set the join rules to public + join_rules_event_dict: JsonDict = { + "auth_events": [create_event.event_id, join_event.event_id], + "content": 
{"join_rule": JoinRules.PUBLIC}, + "depth": 2, + "origin_server_ts": 200, + "prev_events": [join_event.event_id], + "room_id": room_id, + "sender": self.REMOTE1_USER, + "state_key": "", + "type": EventTypes.JoinRules, + } + + add_hashes_and_signatures( + room_version, + join_rules_event_dict, + self.REMOTE1_SERVER_NAME, + self.REMOTE1_SERVER_SIGNATURE_KEY, + ) + join_rules_event = FrozenEventV3(join_rules_event_dict, room_version, {}, None) + events.append(join_rules_event) + + return {(event.type, event.state_key): event for event in events} + + def _build_signed_join_event( + self, + room_id: str, + user: str, + signing_key: SigningKey, + state: StateMap[EventBase], + ) -> FrozenEventV3: + """Build a join event for the local user, signed by the local server.""" + + latest_event = max(state.values(), key=lambda e: e.depth) + + room_version = self.hs.config.server.default_room_version + join_event_dict: JsonDict = { + "auth_events": [ + state[(EventTypes.Create, "")].event_id, + state[(EventTypes.JoinRules, "")].event_id, + ], + "content": {"membership": "join"}, + "depth": latest_event.depth + 1, + "origin_server_ts": latest_event.origin_server_ts + 100, + "prev_events": [latest_event.event_id], + "sender": user, + "state_key": user, + "room_id": room_id, + "type": EventTypes.Member, + } + add_hashes_and_signatures( + room_version, + join_event_dict, + get_domain_from_id(user), + signing_key, + ) + return FrozenEventV3(join_event_dict, room_version, {}, None) + + @parameterized.expand([("not_pruned", False), ("pruned", True)]) + @patch( + "synapse.storage.databases.main.devices.PRUNE_DEVICE_LISTS_CHANGES_IN_ROOM_AGE", + Duration(minutes=1), + ) + def test_local_device_changes_sent_to_new_servers_on_un_partial_state( + self, _test_suffix: str, prune_device_lists_change_in_room: bool + ) -> None: + """When a room un-partials, local device list changes made during the + partial state period should be sent to remote servers that were NOT + known at the time of the 
partial join. + + We do this by creating a room with one remote server, partialling + joining it, then receiving a join event from a second remote server. The + second remote server should receive a device list update EDU for any + local device changes that happened during the partial state period. + + We parameterize this test over whether during the unpartial process we + prune the `device_list_changes_in_room` table, to check that the + unpartial process correctly handles the case. + """ + + local_user = self.register_user("alice", "password") + self.login("alice", "password") + + # Build the remote room's state events. + room_state = self._build_public_room() + + # Before joining, we mock out the federation endpoints that are used + # during the unpartial process, so that we can control when the + # unpartial process completes. + get_room_state_ids_deferred: defer.Deferred[JsonDict] = defer.Deferred() + get_room_state_deferred: defer.Deferred[StateRequestResponse] = defer.Deferred() + self.federation_transport_client.get_room_state_ids = Mock( + side_effect=[get_room_state_ids_deferred] + ) + self.federation_transport_client.get_room_state = Mock( + side_effect=[get_room_state_deferred] + ) + + # Now make the local server partially join the room. + room_id = room_state[(EventTypes.Create, "")].room_id + room_version = room_state[(EventTypes.Create, "")].room_version + + local_join_event = self._build_signed_join_event( + room_id, local_user, self.hs.signing_key, room_state + ) + + # Mock the federation client endpoints for the partial join. + mock_make_membership_event = AsyncMock( + return_value=(self.REMOTE1_SERVER_NAME, local_join_event, room_version) + ) + mock_send_join = AsyncMock( + return_value=SendJoinResult( + local_join_event, + self.REMOTE1_SERVER_NAME, + state=list(room_state.values()), + auth_chain=list(room_state.values()), + partial_state=True, + # Only REMOTE1_SERVER_NAME is known at join time. 
+ servers_in_room={self.REMOTE1_SERVER_NAME}, + ) + ) + + fed_handler = self.hs.get_federation_handler() + fed_client = self.hs.get_federation_client() + with ( + patch.object( + fed_client, "make_membership_event", mock_make_membership_event + ), + patch.object(fed_client, "send_join", mock_send_join), + ): + self.get_success( + fed_handler.do_invite_join( + [self.REMOTE1_SERVER_NAME], room_id, local_user, {} + ) + ) + + # The room should now be in partial state. + self.assertTrue(self.get_success(self.store.is_partial_state_room(room_id))) + + # A local device change happens while the room is in partial state. + self.get_success( + self.store.add_device_change_to_streams( + local_user, ["NEW_DEVICE"], [room_id] + ) + ) + + if prune_device_lists_change_in_room: + # Add a device change for another room, as we won't prune the most + # recent change. + self.get_success( + self.store.add_device_change_to_streams( + "@other:user", ["device1"], ["!some:room"] + ) + ) + + # Now prune the device list changes for the room. This simulates the + # case where the unpartial process prunes the + # `device_list_changes_in_room` table before processing the device + # list changes. + self.reactor.advance(120) # Advance past the pruning threshold + self.get_success(self.store._prune_device_lists_changes_in_room()) + + # Assert we actually pruned the device list changes for the room. 
+ room_ids = self.get_success( + self.store.db_pool.simple_select_onecol( + table="device_lists_changes_in_room", + keyvalues={}, + retcol="room_id", + ) + ) + self.assertCountEqual(room_ids, ["!some:room"]) + + # Join the second server + new_state = dict(room_state) + new_state[(EventTypes.Member, local_user)] = local_join_event + join_event_2 = self._build_signed_join_event( + room_id, + self.REMOTE2_USER, + self.REMOTE2_SERVER_SIGNATURE_KEY, + new_state, + ) + + self.get_success( + self.hs.get_federation_event_handler().on_receive_pdu( + self.REMOTE2_SERVER_NAME, join_event_2 + ) + ) + + # Some EDUs may get sent out immediately, such as presence updates. + # However, we only care about the device list update EDU sent by the + # unpartialling process. Let's wait a few seconds and reset the mock. + self.reactor.advance(5) + self.federation_transport_client.send_transaction.reset_mock() + + # We now unblock the unpartial processs by returning the room state and + # state ids. This should trigger the device list update to be sent to + # REMOTE2_SERVER_NAME. + self.federation_transport_client.get_room_state_ids.assert_called_once_with( + self.REMOTE1_SERVER_NAME, + room_id, + event_id=local_join_event.prev_event_ids()[0], + ) + + get_room_state_ids_deferred.callback( + { + "pdu_ids": [event.event_id for event in room_state.values()], + "auth_event_ids": [], + } + ) + get_room_state_deferred.callback( + StateRequestResponse( + state=list(room_state.values()), + auth_events=[], + ) + ) + + # The device list EDU isn't necessarily sent out immediately + self.reactor.advance(30) + + # Check that only one transaction was sent, and that it contains the + # device list update EDU for the new device to REMOTE2_SERVER_NAME. 
+ self.federation_transport_client.send_transaction.assert_called_once() + args, _ = self.federation_transport_client.send_transaction.call_args + transaction: Transaction = args[0] + + self.assertEqual(transaction.destination, self.REMOTE2_SERVER_NAME) + self.assertEqual(len(transaction.edus), 1) + + edu = transaction.edus[0] + self.assertEqual(edu["edu_type"], "m.device_list_update") + self.assertEqual(edu["content"]["device_id"], "NEW_DEVICE") diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 3d856b9346..1aaa86e2e8 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -1121,7 +1121,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): return {"pdus": [missing_event.get_pdu_json()]} async def get_room_state_ids( - destination: str, room_id: str, event_id: str + destination: str, + room_id: str, + event_id: str, ) -> JsonDict: self.assertEqual(destination, self.OTHER_SERVER_NAME) self.assertEqual(event_id, missing_event.event_id) @@ -1131,7 +1133,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): } async def get_room_state( - room_version: RoomVersion, destination: str, room_id: str, event_id: str + room_version: RoomVersion, + destination: str, + room_id: str, + event_id: str, ) -> StateRequestResponse: self.assertEqual(destination, self.OTHER_SERVER_NAME) self.assertEqual(event_id, missing_event.event_id) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 12c8942bc8..f1b20a12ec 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -828,6 +828,24 @@ class ModuleApiTestCase(BaseModuleApiTestCase): # Ensure the pushers were deleted after the callback. 
self.assertEqual(len(self.hs.get_pusherpool().pushers[user_id].values()), 0) + def test_event_deprecated_methods(self) -> None: + """Test that deprecated methods on events are still functional.""" + user_id = self.register_user("user", "password") + tok = self.login("user", "password") + + room_id = self.helper.create_room_as(tok=tok) + + state = self.get_success( + self.hs.get_storage_controllers().state.get_current_state(room_id) + ) + create_event = state[(EventTypes.Create, "")] + + # `.user_id` is a deprecated alias for `.sender`. + self.assertEqual(create_event.user_id, user_id) + + # The event supports looking up keys via `__getitem__` although deprecated + self.assertEqual(create_event["room_id"], room_id) + class ModuleApiWorkerTestCase(BaseModuleApiTestCase, BaseMultiWorkerStreamTestCase): """For testing ModuleApi functionality in a multi-worker setup""" diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 77d824dcd8..b77a72ec4a 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -18,12 +18,15 @@ # [This file includes modifications made by New Vector Limited] # # +from __future__ import annotations import urllib.parse -from typing import cast +from typing import Any, cast +from unittest.mock import Mock from parameterized import parameterized +from twisted.internet.defer import Deferred from twisted.internet.testing import MemoryReactor from twisted.web.resource import Resource @@ -70,6 +73,24 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): resources["/_matrix/media"] = self.hs.get_media_repository_resource() return resources + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.fetches: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + + # A remote fetch of media that was not intentional. + # Used to check that remote media fetches do NOT happen. 
+ def unexpected_remote_fetch(*args: Any, **kwargs: Any) -> Deferred[Any]: + self.fetches.append((args, kwargs)) + return Deferred() + + client = Mock() + client.federation_get_file = unexpected_remote_fetch + client.get_file = unexpected_remote_fetch + + return self.setup_test_homeserver( + clock=clock, + federation_http_client=client, + ) + def _ensure_quarantined( self, user_tok: str, @@ -176,6 +197,28 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ), ) + def test_non_admin_bypass_does_not_fetch_remote_media(self) -> None: + self.register_user("nonadmin", "pass", admin=False) + non_admin_user_tok = self.login("nonadmin", "pass") + + channel = self.make_request( + "GET", + "/_matrix/client/v1/media/download/example.com/remote_media" + "?admin_unsafely_bypass_quarantine=true", + shorthand=False, + access_token=non_admin_user_tok, + await_result=False, + ) + self.pump() + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual( + channel.json_body["error"], + "Must be a server admin to bypass quarantine", + ) + # Check that a remote fetch attempt did not occur. + self.assertEqual(self.fetches, []) + @parameterized.expand( [ # Attempt quarantine media APIs as non-admin diff --git a/tests/rest/admin/test_user_reports.py b/tests/rest/admin/test_user_reports.py new file mode 100644 index 0000000000..94ae242d86 --- /dev/null +++ b/tests/rest/admin/test_user_reports.py @@ -0,0 +1,644 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2026 Element Creations Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . 
+# + + +from twisted.internet.testing import MemoryReactor + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client import login, reporting, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util.clock import Clock + +from tests import unittest + + +class UserReportsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + reporting.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.users = {} + for i in range(10): + self.users[i] = self.register_user(f"user{i}", "pass") + + # users 1 and 2 report all other users + reporter_1_tok = self.login(self.users[0], "pass") + reporter_2_tok = self.login(self.users[1], "pass") + for num, user in self.users.items(): + if num <= 1: + continue + if num % 2 == 0: + self._report_user(user, reporter_1_tok) + else: + self._report_user(user, reporter_2_tok) + + self.url = "/_synapse/admin/v1/user_reports" + + def test_no_auth(self) -> None: + """ + Try to get a user report without authentication. + """ + channel = self.make_request("GET", self.url, {}) + + self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self) -> None: + """ + If the user is not a server admin, an error 403 is returned. 
+ """ + rando_tok = self.login(self.users[4], "pass") + channel = self.make_request( + "GET", + self.url, + access_token=rando_tok, + ) + + self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_default_success(self) -> None: + """ + Testing list of reported users + """ + + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 8) + self.assertEqual(len(channel.json_body["user_reports"]), 8) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["user_reports"]) + + def test_limit(self) -> None: + """ + Testing list of reported users with limit + """ + + channel = self.make_request( + "GET", + self.url + "?limit=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 8) + self.assertEqual(len(channel.json_body["user_reports"]), 5) + self.assertEqual(channel.json_body["next_token"], 5) + self._check_fields(channel.json_body["user_reports"]) + + def test_from(self) -> None: + """ + Testing list of reported users with a defined starting point (from) + """ + + channel = self.make_request( + "GET", + self.url + "?from=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 8) + self.assertEqual(len(channel.json_body["user_reports"]), 3) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["user_reports"]) + + def test_limit_and_from(self) -> None: + """ + Testing list of reported users with a defined starting point and limit + """ + + channel = self.make_request( + "GET", + self.url + "?from=2&limit=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, 
msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 8) + self.assertEqual(channel.json_body["next_token"], 7) + self.assertEqual(len(channel.json_body["user_reports"]), 5) + self._check_fields(channel.json_body["user_reports"]) + + def test_filter_by_target_user_id(self) -> None: + """ + Testing list of reported users with a filter of target_user_id + """ + + channel = self.make_request( + "GET", + self.url + "?target_user_id=%s" % self.users[3], + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 1) + self.assertEqual(len(channel.json_body["user_reports"]), 1) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["user_reports"]) + + for report in channel.json_body["user_reports"]: + self.assertEqual(report["target_user_id"], self.users[3]) + + def test_filter_user(self) -> None: + """ + Testing list of reported users with a filter of reporting user + """ + + channel = self.make_request( + "GET", + self.url + "?user_id=%s" % self.users[0], + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 4) + self.assertEqual(len(channel.json_body["user_reports"]), 4) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["user_reports"]) + + for report in channel.json_body["user_reports"]: + self.assertEqual(report["user_id"], self.users[0]) + + def test_filter_user_and_target_user(self) -> None: + """ + Testing list of reported users with a filter of reporting user and target_user + """ + + channel = self.make_request( + "GET", + self.url + "?user_id=%s&target_user_id=%s" % (self.users[1], self.users[7]), + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 1) + 
self.assertEqual(len(channel.json_body["user_reports"]), 1) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["user_reports"]) + + for report in channel.json_body["user_reports"]: + self.assertEqual(report["user_id"], self.users[1]) + self.assertEqual(report["target_user_id"], self.users[7]) + + def test_valid_search_order(self) -> None: + """ + Testing search order. Order by timestamps. + """ + + # fetch the most recent first, largest timestamp + channel = self.make_request( + "GET", + self.url + "?dir=b", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 8) + self.assertEqual(len(channel.json_body["user_reports"]), 8) + report = 1 + while report < len(channel.json_body["user_reports"]): + self.assertGreaterEqual( + channel.json_body["user_reports"][report - 1]["received_ts"], + channel.json_body["user_reports"][report]["received_ts"], + ) + report += 1 + + # fetch the oldest first, smallest timestamp + channel = self.make_request( + "GET", + self.url + "?dir=f", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 8) + self.assertEqual(len(channel.json_body["user_reports"]), 8) + report = 1 + while report < len(channel.json_body["user_reports"]): + self.assertLessEqual( + channel.json_body["user_reports"][report - 1]["received_ts"], + channel.json_body["user_reports"][report]["received_ts"], + ) + report += 1 + + def test_invalid_search_order(self) -> None: + """ + Testing that a invalid search order returns a 400 + """ + + channel = self.make_request( + "GET", + self.url + "?dir=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "Query parameter 'dir' must be one of ['b', 'f']", + 
channel.json_body["error"], + ) + + def test_limit_is_negative(self) -> None: + """ + Testing that a negative limit parameter returns a 400 + """ + + channel = self.make_request( + "GET", + self.url + "?limit=-5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_from_is_negative(self) -> None: + """ + Testing that a negative from parameter returns a 400 + """ + + channel = self.make_request( + "GET", + self.url + "?from=-5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_next_token(self) -> None: + """ + Testing that `next_token` appears at the right place + """ + + # `next_token` does not appear + # Number of results is the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=8", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 8) + self.assertEqual(len(channel.json_body["user_reports"]), 8) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=10", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 8) + self.assertEqual(len(channel.json_body["user_reports"]), 8) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=6", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 
8) + self.assertEqual(len(channel.json_body["user_reports"]), 6) + self.assertEqual(channel.json_body["next_token"], 6) + + # Check + # Set `from` to value of `next_token` for request remaining entries + # `next_token` does not appear + channel = self.make_request( + "GET", + self.url + "?from=6", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 8) + self.assertEqual(len(channel.json_body["user_reports"]), 2) + self.assertNotIn("next_token", channel.json_body) + + def _report_user(self, target_user: str, reporter_tok: str) -> None: + """Report a user""" + channel = self.make_request( + "POST", + "_matrix/client/v3/users/%s/report" % (target_user), + {"reason": "stone-cold bummer"}, + access_token=reporter_tok, + shorthand=False, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + def _check_fields(self, content: list[JsonDict]) -> None: + """Checks that all attributes are present in a user report""" + for c in content: + self.assertIn("id", c) + self.assertIn("received_ts", c) + self.assertIn("user_id", c) + self.assertIn("target_user_id", c) + self.assertIn("reason", c) + + +class UserReportDetailTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + reporting.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.bad_user = self.register_user("user", "pass") + self.bad_user_tok = self.login("user", "pass") + + channel = self.make_request( + "POST", + "_matrix/client/v3/users/%s/report" % (self.bad_user), + {"reason": "stone-cold bummer"}, + access_token=self.admin_user_tok, + shorthand=False, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # first created 
user report gets `id`=2 + self.url = "/_synapse/admin/v1/user_reports/2" + + def test_no_auth(self) -> None: + """ + Try to get user report without authentication. + """ + channel = self.make_request("GET", self.url, {}) + + self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self) -> None: + """ + If the user is not a server admin, an error 403 is returned. + """ + + channel = self.make_request( + "GET", + self.url, + access_token=self.bad_user_tok, + ) + + self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_default_success(self) -> None: + """ + Testing get a reported user + """ + + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self._check_fields(channel.json_body) + + def test_invalid_report_id(self) -> None: + """ + Testing that an invalid `report_id` returns a 400. 
+ """ + + # `report_id` is negative + channel = self.make_request( + "GET", + "/_synapse/admin/v1/user_reports/-123", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is a non-numerical string + channel = self.make_request( + "GET", + "/_synapse/admin/v1/user_reports/abcdef", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is undefined + channel = self.make_request( + "GET", + "/_synapse/admin/v1/user_reports/", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + def test_report_id_not_found(self) -> None: + """ + Testing that a not existing `report_id` returns a 404. 
+ """ + + channel = self.make_request( + "GET", + "/_synapse/admin/v1/user_reports/123", + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual("User report not found", channel.json_body["error"]) + + def _check_fields(self, content: JsonDict) -> None: + """Checks that all attributes are present in a user report""" + self.assertIn("id", content) + self.assertIn("received_ts", content) + self.assertIn("target_user_id", content) + self.assertIn("user_id", content) + self.assertIn("reason", content) + + +class DeleteUserReportTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self._store = hs.get_datastores().main + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.bad_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + # create report + self.get_success( + self._store.add_user_report( + self.bad_user, + self.admin_user, + "super bummer", + self.clock.time_msec(), + ) + ) + + self.url = "/_synapse/admin/v1/user_reports/2" + + def test_no_auth(self) -> None: + """ + Try to delete user report without authentication. + """ + channel = self.make_request("DELETE", self.url) + + self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self) -> None: + """ + If the user is not a server admin, an error 403 is returned. 
+ """ + + channel = self.make_request( + "DELETE", + self.url, + access_token=self.other_user_tok, + ) + + self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_delete_success(self) -> None: + """ + Testing delete a report. + """ + + channel = self.make_request( + "DELETE", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual({}, channel.json_body) + + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + + # check that report was deleted + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_invalid_report_id(self) -> None: + """ + Testing that an invalid `report_id` returns a 400. + """ + + # `report_id` is negative + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/user_reports/-123", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is a non-numerical string + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/user_reports/abcdef", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is undefined + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/user_reports/", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, 
channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + def test_report_id_not_found(self) -> None: + """ + Testing that a not existing `report_id` returns a 404. + """ + + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/user_reports/123", + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual("User report not found", channel.json_body["error"]) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 1d1979e19f..b153c74980 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -19,15 +19,21 @@ # # +import itertools from typing import Collection +from unittest.mock import patch from twisted.internet.testing import MemoryReactor import synapse.api.errors from synapse.api.constants import EduTypes from synapse.server import HomeServer +from synapse.storage.databases.main.devices import ( + PRUNE_DEVICE_LISTS_CHANGES_IN_ROOM_AGE, +) from synapse.types import JsonDict from synapse.util.clock import Clock +from synapse.util.duration import Duration from tests.unittest import HomeserverTestCase @@ -351,3 +357,101 @@ class DeviceStoreTestCase(HomeserverTestCase): synapse.api.errors.StoreError, ) self.assertEqual(404, exc.value.code) + + @patch("synapse.storage.databases.main.devices.PRUNE_DEVICE_LISTS_BATCH_SIZE", 5) + def test_prune_old_device_lists_changes_in_room(self) -> None: + """Test that old entries in the `device_lists_changes_in_room` table are pruned properly.""" + + # Pretend the user is in a few rooms. + room_ids = [f"!room{i}:test" for i in range(20)] + + # Create a generator for device IDs so we can easily create many unique + # device IDs without having to keep track of the count ourselves. 
+ device_id_gen = (f"device_id{i}" for i in itertools.count()) + + def get_devices_in_room_status() -> tuple[int, str]: + """Helper function to get the count of entries in + `device_lists_changes_in_room` and the minimum device_id.""" + return self.get_success( + self.store.db_pool.simple_select_one( + table="device_lists_changes_in_room", + keyvalues={}, + retcols=("COUNT(*)", "MIN(device_id)"), + ) + ) + + # First we add some initial entries to the `device_lists_changes_in_room`. + self.get_success( + self.store.add_device_change_to_streams( + user_id="@user_id:test", + device_ids=[next(device_id_gen) for _ in range(10)], + room_ids=room_ids, + ) + ) + + # Advance the reactor a while, but not long enough to trigger pruning. + self.reactor.advance(Duration(hours=1).as_secs()) + + # The `device_lists_changes_in_room` table should now have 10 * + # len(room_ids) entries, and the minimum device_id should be + # `device_id0`. + count, min_device_id = get_devices_in_room_status() + self.assertEqual(count, 10 * len(room_ids)) + self.assertEqual(min_device_id, "device_id0") + + # Record the max pruned stream ID before pruning, so we can check + # that this correctly updates after pruning. + starting_max_pruned_id = self.get_success( + self.store.db_pool.runInteraction( + "get_max_pruned_device_lists_changes_in_room", + self.store._get_max_pruned_device_lists_changes_in_room_txn, + ) + ) + + # Now we add some more entries. + self.get_success( + self.store.add_device_change_to_streams( + user_id="@user_id:test", + device_ids=[next(device_id_gen) for _ in range(10)], + room_ids=room_ids, + ) + ) + + # Advance the reactor a while more, so that the first batch of entries is + # now old enough to be pruned. + self.reactor.advance( + (PRUNE_DEVICE_LISTS_CHANGES_IN_ROOM_AGE - Duration(minutes=30)).as_secs() + ) + + # Advance repeatedly a bit so that the pruning process can run to completion. 
+ for _ in range(10): + self.reactor.advance(Duration(milliseconds=110).as_secs()) + + # Check that the old entries have been pruned, and the new entries are still there. + count, min_device_id = get_devices_in_room_status() + self.assertEqual(count, 10 * len(room_ids)) + self.assertEqual(min_device_id, "device_id10") + + # We should always keep the most recent entries, even if they are old enough to be pruned. + self.reactor.advance( + (PRUNE_DEVICE_LISTS_CHANGES_IN_ROOM_AGE + Duration(minutes=30)).as_secs() + ) + + # Advance repeatedly a bit so that the pruning process can run to completion. + for _ in range(10): + self.reactor.advance(Duration(milliseconds=110).as_secs()) + + count, min_device_id = get_devices_in_room_status() + # We should always keep the most recent entries so that we can + # calculate the maximum stream ID used. + self.assertEqual(count, len(room_ids)) + self.assertEqual(min_device_id, "device_id19") + + # Check that the max pruned stream ID has been advanced after pruning. + max_pruned_id = self.get_success( + self.store.db_pool.runInteraction( + "get_max_pruned_device_lists_changes_in_room", + self.store._get_max_pruned_device_lists_changes_in_room_txn, + ) + ) + self.assertGreater(max_pruned_id, starting_max_pruned_id) diff --git a/tests/storage/test_msc4242_state_dag.py b/tests/storage/test_msc4242_state_dag.py new file mode 100644 index 0000000000..8775e5c8eb --- /dev/null +++ b/tests/storage/test_msc4242_state_dag.py @@ -0,0 +1,371 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2026 Element Creations, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . 
+ +from typing import Iterable +from unittest.mock import Mock + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import EventTypes +from synapse.api.errors import SynapseError +from synapse.api.room_versions import RoomVersions +from synapse.events import FrozenEventVMSC4242, make_event_from_dict +from synapse.events.snapshot import EventContext +from synapse.rest.client import room +from synapse.server import HomeServer +from synapse.util.clock import Clock + +from tests.unittest import HomeserverTestCase, override_config + + +class MSC4242StateDagsTests(HomeserverTestCase): + user_id = "@user1:server" + servlets = [room.register_servlets] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + hs = self.setup_test_homeserver("server") + return hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.room_id = self.helper.create_room_as( + self.user_id, + room_version=RoomVersions.MSC4242v12.identifier, + ) + + self.store = hs.get_datastores().main + self._storage_controllers = self.hs.get_storage_controllers() + + def _get_prev_state_events(self, event_id: str) -> list[str]: + ev = self.helper.get_event(self.room_id, event_id) + prev_state_events: list[str] | None = ev.get("prev_state_events", None) + assert prev_state_events is not None + return prev_state_events + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_forward_extremities_are_calculated(self) -> None: + """ + Check that forward extremities are set as prev_state_events and that they don't change + for non-state events. 
+ """ + # they don't change for messages + first_event_id = self.helper.send(self.room_id, body="test1")["event_id"] + first_prev_state_events = self._get_prev_state_events(first_event_id) + assert len(first_prev_state_events) == 1 + second_id = self.helper.send(self.room_id, body="test2")["event_id"] + second_prev_state_events = self._get_prev_state_events(second_id) + assert len(second_prev_state_events) == 1 + self.assertIncludes( + set(first_prev_state_events), set(second_prev_state_events), exact=True + ) + + # send an auth event, which should change the prev_state_events on *subsequent* events + join_rule_state_event_id = self.helper.send_state( + self.room_id, + EventTypes.JoinRules, + { + "join_rule": "knock", + }, + tok="nope", + )["event_id"] + join_rule_prev_state_event_ids = self._get_prev_state_events( + join_rule_state_event_id + ) + self.assertIncludes( + set(second_prev_state_events), + set(join_rule_prev_state_event_ids), + exact=True, + ) + + # prev_state_events should always point to the join rule now + third_event_id = self.helper.send(self.room_id, body="test3")["event_id"] + third_prev_state_events = self._get_prev_state_events(third_event_id) + self.assertIncludes( + set(third_prev_state_events), {join_rule_state_event_id}, exact=True + ) + # and non-auth state should also update prev_state_events + name_state_event_id = self.helper.send_state( + self.room_id, + EventTypes.Name, + { + "name": "State DAGs!", + }, + tok="nope", + )["event_id"] + name_prev_state_event_ids = self._get_prev_state_events(name_state_event_id) + self.assertIncludes( + set(name_prev_state_event_ids), {join_rule_state_event_id}, exact=True + ) + fourth_event_id = self.helper.send(self.room_id, body="test4")["event_id"] + fourth_prev_state_events = self._get_prev_state_events(fourth_event_id) + self.assertIncludes( + set(fourth_prev_state_events), {name_state_event_id}, exact=True + ) + + +class MSC4242EventPersistenceStateDagsStoreTestCase(HomeserverTestCase): + 
servlets = [ + room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + persistence = hs.get_storage_controllers().persistence + assert persistence is not None + self.persistence = persistence + self.room_id = "!foo:bar" + self.seen_event_ids: set[str] = set() + self.persistence.main_store = Mock(spec=["have_seen_events"]) + self.persistence.main_store.have_seen_events.side_effect = ( + self._have_seen_events + ) + self.rejected_event_ids_and_their_prevs: set[str] = set() + self.persistence.persist_events_store = Mock( + spec=["_get_prevs_before_rejected"] + ) + self.persistence.persist_events_store._get_prevs_before_rejected.side_effect = ( + self._get_prevs_before_rejected + ) + + async def _have_seen_events( + self, room_id: str, event_ids: Iterable[str] + ) -> set[str]: + unknown_events = set(event_ids) + return self.seen_event_ids.intersection(unknown_events) + + async def _get_prevs_before_rejected( + self, event_ids: Iterable[str], include_soft_failed: bool = True + ) -> set[str]: + return self.rejected_event_ids_and_their_prevs + + def _make_event( + self, + id: str, + prev_state_events: list[str], + rejected: bool = False, + ) -> tuple[FrozenEventVMSC4242, EventContext]: + ev = make_event_from_dict( + { + "prev_state_events": prev_state_events, + "content": { + "membership": "join", + }, + "sender": "@unimportant:info", + "state_key": "@unimportant:info", + "type": "m.room.member", + "room_id": self.room_id, + }, + room_version=RoomVersions.MSC4242v12, + ) + assert isinstance(ev, FrozenEventVMSC4242) + ev._event_id = id + ctx = Mock() + ctx.rejected = rejected + return ev, ctx + + def _test( + self, + current_fwds: list[str], + new_events: list[tuple[FrozenEventVMSC4242, EventContext]], + want_new_extrems: set[str], + want_raises: bool = False, + ) -> None: + """ + Tests the logic of _calculate_new_state_dag_extremities. 
+ + Tests that the new extremities calculated as a result of processing current_fwds and new_events + matches want_new_extrems or raises if want_raises is True. + """ + coroutine = self.persistence._calculate_new_state_dag_extremities( + self.room_id, + frozenset(current_fwds), + new_events, + ) + if want_raises: + f = self.get_failure(coroutine, SynapseError) + assert f is not None + return + + new_extrems = set(self.get_success(coroutine)) + self.assertIncludes( + new_extrems, + set(want_new_extrems), + exact=True, + message=f"want_new_extrems={want_new_extrems} got={new_extrems}", + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_simple(self) -> None: + # Simple linear chain + self._test( + current_fwds=[], + new_events=[ + self._make_event("$1", []), + self._make_event("$2", ["$1"]), + self._make_event("$3", ["$2"]), + self._make_event("$4", ["$3"]), + ], + want_new_extrems={"$4"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_fork(self) -> None: + # Simple fork so we end up with two forward extrems + self._test( + current_fwds=[], + new_events=[ + self._make_event("$1", []), + self._make_event("$2", ["$1"]), + self._make_event("$3", ["$2"]), + self._make_event("$4", ["$2"]), + ], + want_new_extrems={"$3", "$4"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_merge(self) -> None: + # Simple fork so we end up with two forward extrems + self._test( + current_fwds=[], + new_events=[ + self._make_event("$1", []), + self._make_event("$2", ["$1"]), + self._make_event("$3", ["$1"]), + self._make_event("$4", ["$2", "$3"]), + ], + want_new_extrems={"$4"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_fork_on_existing(self) -> None: + # Fork where we are adding 
to older events + self.seen_event_ids = {"$1", "$2", "$3"} + self._test( + current_fwds=["$3"], + new_events=[ + self._make_event("$4", ["$3"]), # append to the forward extrem + self._make_event("$5", ["$1"]), # append to the root + ], + want_new_extrems={"$4", "$5"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_merge_on_existing(self) -> None: + # Merge where we are merging to older events + self.seen_event_ids = {"$1", "$2", "$3"} + self._test( + current_fwds=["$3"], + new_events=[ + self._make_event("$4", ["$3", "$2"]), + ], + want_new_extrems={"$4"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_merge_on_not_current(self) -> None: + # Merge where we are merging to older events + self.seen_event_ids = {"$1", "$2", "$3"} + self._test( + current_fwds=["$3"], + new_events=[ + self._make_event("$4", ["$1", "$2"]), + ], + want_new_extrems={"$3", "$4"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_append_with_rejected(self) -> None: + # rejected events cannot be forward extremities + self.seen_event_ids = {"$1", "$2", "$3"} + self._test( + current_fwds=["$3"], + new_events=[ + self._make_event("$4", ["$3"], rejected=True), + ], + want_new_extrems={"$3"}, + ) + + self._test( + current_fwds=["$3"], + new_events=[ + self._make_event("$4", ["$3"], rejected=True), + self._make_event("$5", ["$4"], rejected=True), + ], + want_new_extrems={"$3"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_append_with_rejected_in_chain( + self, + ) -> None: + # rejected events cannot be forward extremities, but events that come after them can. + # this shouldn't cause multiple forward extremities. 
+ self.seen_event_ids = {"$1", "$2", "$3"} + self.rejected_event_ids_and_their_prevs = {"$4", "$3"} + self._test( + current_fwds=["$3"], + new_events=[ + self._make_event("$4", ["$3"], rejected=True), + self._make_event("$5", ["$4"]), + ], + want_new_extrems={"$5"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_missing_prevs_raises(self) -> None: + self._test( + current_fwds=[], + new_events=[ + self._make_event("$1", []), + self._make_event("$2", ["$1"]), + self._make_event("$3", ["$unknown"]), + self._make_event("$4", ["$3"]), + ], + want_new_extrems={"$4"}, + want_raises=True, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_complex(self) -> None: + """ + 1 + | \ + 2 4 + | + 3 + + Exists already, then becomes... + + 1______ + | \\ | + 2 4 5R + | | | + 3--7 6R + | \\ / \ + 10R 8 9 + + """ + # Merge where we are merging to older events + self.seen_event_ids = {"$1", "$2", "$3", "$4"} + self.rejected_event_ids_and_their_prevs = {"$1", "$5", "$6", "$3", "$10"} + self._test( + current_fwds=["$3", "$4"], + new_events=[ + self._make_event("$5", ["$1"], rejected=True), + self._make_event("$6", ["$5"], rejected=True), + self._make_event("$7", ["$4", "$3"]), + self._make_event("$8", ["$6", "$7"]), + self._make_event("$9", ["$6"]), + self._make_event("$10", ["$3"], rejected=True), + ], + want_new_extrems={"$8", "$9"}, + ) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 67a5c31c44..c346245706 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -232,6 +232,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): prev_event_ids: list[str], auth_event_ids: list[str] | None, depth: int | None = None, + prev_state_events: list[str] | None = None, ) -> EventBase: built_event = await self._base_builder.build( prev_event_ids=prev_event_ids, 
auth_event_ids=auth_event_ids diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 934a2fd307..9258f0d4dc 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -20,15 +20,16 @@ # import unittest +from collections import namedtuple from typing import Any, Collection, Iterable from parameterized import parameterized from synapse import event_auth -from synapse.api.constants import EventContentFields +from synapse.api.constants import EventContentFields, RejectedReason from synapse.api.errors import AuthError, SynapseError from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions -from synapse.events import EventBase, make_event_from_dict +from synapse.events import EventBase, event_exists_in_state_dag, make_event_from_dict from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import JsonDict, get_domain_from_id @@ -374,6 +375,195 @@ class EventAuthTestCase(unittest.TestCase): auth_events, ) + def test_msc4242_state_dag_rules(self) -> None: + """Tests additional rules in place for state DAG rooms. + + 1. m.room.create => if it has any prev_state_events, reject. + 2. Considering the event's prev_state_events: + i. If there are entries which do not belong in the same room, reject. + ii. If there are entries which do not have a state_key, reject. + iii. If there are entries which were themselves rejected under the checks performed on receipt of a PDU, reject. 
+ """ + creator = "@creator:example.com" + room_version = RoomVersions.MSC4242v12 + + create_event = make_event_from_dict( + { + "type": "m.room.create", + "sender": creator, + "state_key": "", + "content": {"creator": creator}, + "prev_events": [], + "prev_state_events": [], + }, + room_version, + ) + create_event_2 = make_event_from_dict( + { + "type": "m.room.create", + "sender": creator, + "state_key": "", + "content": {"creator": creator, "another": "room"}, + "prev_events": [], + "prev_state_events": [], + }, + room_version, + ) + room_id = create_event.room_id + another_room_id = create_event_2.room_id + join_event = make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.member", + "sender": creator, + "state_key": creator, + "content": {"membership": "join"}, + "prev_events": [create_event.event_id], + "prev_state_events": [create_event.event_id], + }, + room_version, + {"calculated_auth_event_ids": [create_event.event_id]}, + ) + event_in_another_room = make_event_from_dict( + { + "room_id": another_room_id, + "type": "m.room.join_rules", + "sender": creator, + "state_key": "", + "content": {"join_rule": "public"}, + "prev_events": [join_event.event_id], + "prev_state_events": [join_event.event_id], + }, + room_version, + {"calculated_auth_event_ids": [create_event.event_id, join_event.event_id]}, + ) + msg_event = make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.message", + "sender": creator, + "content": {"msgtype": "m.text", "body": "I am a message"}, + "prev_events": [join_event.event_id], + "prev_state_events": [join_event.event_id], + }, + room_version, + {"calculated_auth_event_ids": [create_event.event_id, join_event.event_id]}, + ) + rejected_event = make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.name", + "sender": creator, + "state_key": "", + "content": {"name": "REJECTED"}, + "prev_events": [join_event.event_id], + "prev_state_events": [join_event.event_id], + }, + room_version, + 
{"calculated_auth_event_ids": [create_event.event_id, join_event.event_id]}, + rejected_reason=RejectedReason.AUTH_ERROR, + ) + RejectingTestCase = namedtuple( + "RejectingTestCase", "name events_in_store test_event" + ) + rejecting_test_cases = [ + RejectingTestCase( + name="create event has prev_state_events", + events_in_store=[], + test_event=make_event_from_dict( + { + "type": "m.room.create", + "sender": creator, + "state_key": "", + "content": {"creator": creator}, + "prev_events": [], + "prev_state_events": [create_event.event_id], + }, + room_version, + {}, + ), + ), + RejectingTestCase( + name="prev_state_event belongs in a different room", + events_in_store=[create_event, join_event, event_in_another_room], + test_event=make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.name", + "sender": creator, + "state_key": "", + "content": {"name": "prev_state_event is in another room"}, + "prev_events": [join_event.event_id], + "prev_state_events": [event_in_another_room.event_id], + }, + room_version, + { + "calculated_auth_event_ids": [ + create_event.event_id, + join_event.event_id, + ] + }, + ), + ), + RejectingTestCase( + name="prev_state_event is a message event", + events_in_store=[create_event, join_event, msg_event], + test_event=make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.name", + "sender": creator, + "state_key": "", + "content": {"name": "prev state event is a message"}, + "prev_events": [msg_event.event_id], + "prev_state_events": [msg_event.event_id], + }, + room_version, + { + "calculated_auth_event_ids": [ + create_event.event_id, + join_event.event_id, + ] + }, + ), + ), + RejectingTestCase( + name="prev_state_event was rejected", + events_in_store=[create_event, join_event, rejected_event], + test_event=make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.name", + "sender": creator, + "state_key": "", + "content": {"name": "prev state event was rejected"}, + "prev_events": [join_event.event_id], + 
"prev_state_events": [rejected_event.event_id], + }, + room_version, + { + "calculated_auth_event_ids": [ + create_event.event_id, + join_event.event_id, + ] + }, + ), + ), + ] + + for test_case in rejecting_test_cases: + event_store = _StubEventSourceStore() + event_store.add_events(test_case.events_in_store) + + with self.assertRaises( + AuthError, msg=f"test case {test_case.name} was not rejected" + ): + get_awaitable_result( + event_auth.check_state_independent_auth_rules( + event_store, test_case.test_event + ) + ) + @parameterized.expand([(RoomVersions.V1, True), (RoomVersions.V6, False)]) def test_notifications( self, room_version: RoomVersion, allow_modification: bool @@ -769,6 +959,105 @@ class EventAuthTestCase(unittest.TestCase): with self.assertRaises(SynapseError): event_auth._check_power_levels(event.room_version, event, {}) + def test_event_exists_in_state_dag(self) -> None: + events_that_exist_in_state_dag = [ + { + "type": "m.room.create", + "state_key": "", + "content": {}, + }, + { + "type": "m.room.join_rules", + "state_key": "", + "content": {}, + }, + { + "type": "m.room.power_levels", + "state_key": "", + "content": {}, + }, + { + "type": "m.room.server_acl", + "state_key": "", + "content": {}, + }, + { + "type": "m.room.member", + "state_key": "@alice:somewhere", + "content": {}, + }, + { + "type": "m.room.third_party_invite", + "state_key": "flibble", + "content": {}, + }, + { + "type": "m.room.create", + "state_key": " ", + "content": {}, + }, + { + "type": "m.room.join_rules", + "state_key": " ", + "content": {}, + }, + { + "type": "m.room.power_levels", + "state_key": " ", + "content": {}, + }, + { + "type": "m.room.name", + "state_key": "", + "content": {}, + }, + { + "type": "m.room.member", + "state_key": "", + "content": {}, + }, + { + "type": "m.room.member", + "state_key": "hello_world", + "content": {}, + }, + ] + events_that_dont_exist_in_state_dag = [ + { + "type": "m.room.message", + "content": {}, + }, + { + "type": 
"m.room.create", + "content": {}, + }, + { + "type": "m.room.join_rules", + "content": {}, + }, + { + "type": "m.room.power_levels", + "content": {}, + }, + ] + + def check_events(events: list[dict], should_exist: bool) -> None: + for ev in events: + base = { + "room_id": TEST_ROOM_ID, + "sender": "@test:test.com", + "signatures": {"test.com": {"ed25519:0": "some9signature"}}, + } + base.update(ev) + event = make_event_from_dict(base, RoomVersions.V10) + got = event_exists_in_state_dag(event) + self.assertEqual( + got, should_exist, f"{ev} should_exist={should_exist} but got {got}" + ) + + check_events(events_that_exist_in_state_dag, should_exist=True) + check_events(events_that_dont_exist_in_state_dag, should_exist=False) + # helpers for making events TEST_DOMAIN = "example.com"