mirror of
https://github.com/element-hq/matrix-authentication-service.git
synced 2026-04-26 02:12:27 +00:00
syn2mas: Support migrating external IDs as upstream OAuth2 providers (#3917)
* Add `SynapseReader` support and test for external IDs * Run database migrations and do a config sync before syn2mas * FullUserId: implement Display * Add `MasWriter` support and test for upstream OAuth provider links * Remove special-purpose write buffers and use only the generic one * Build the provider ID mapping
This commit is contained in:
@@ -1,14 +1,19 @@
|
||||
use std::process::ExitCode;
|
||||
use std::{collections::HashMap, process::ExitCode};
|
||||
|
||||
use anyhow::Context;
|
||||
use camino::Utf8PathBuf;
|
||||
use clap::Parser;
|
||||
use figment::Figment;
|
||||
use mas_config::{ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig};
|
||||
use mas_config::{
|
||||
ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig, SyncConfig,
|
||||
UpstreamOAuth2Config,
|
||||
};
|
||||
use mas_storage::SystemClock;
|
||||
use mas_storage_pg::MIGRATOR;
|
||||
use rand::thread_rng;
|
||||
use sqlx::{postgres::PgConnectOptions, Connection, Either, PgConnection};
|
||||
use sqlx::{postgres::PgConnectOptions, types::Uuid, Connection, Either, PgConnection};
|
||||
use syn2mas::{synapse_config, LockedMasDatabase, MasWriter, SynapseReader};
|
||||
use tracing::{error, warn};
|
||||
use tracing::{error, info_span, warn, Instrument};
|
||||
|
||||
use crate::util::database_connection_from_config;
|
||||
|
||||
@@ -75,6 +80,7 @@ enum Subcommand {
|
||||
const NUM_WRITER_CONNECTIONS: usize = 8;
|
||||
|
||||
impl Options {
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
|
||||
warn!("This version of the syn2mas tool is EXPERIMENTAL and INCOMPLETE. Do not use it, except for TESTING.");
|
||||
if !self.experimental_accepted {
|
||||
@@ -107,6 +113,35 @@ impl Options {
|
||||
|
||||
let mut mas_connection = database_connection_from_config(&config).await?;
|
||||
|
||||
MIGRATOR
|
||||
.run(&mut mas_connection)
|
||||
.instrument(info_span!("db.migrate"))
|
||||
.await
|
||||
.context("could not run migrations")?;
|
||||
|
||||
if matches!(&self.subcommand, Subcommand::Migrate { .. }) {
|
||||
// First perform a config sync
|
||||
// This is crucial to ensure we register upstream OAuth providers
|
||||
// in the MAS database
|
||||
//
|
||||
let config = SyncConfig::extract(figment)?;
|
||||
let clock = SystemClock::default();
|
||||
let encrypter = config.secrets.encrypter();
|
||||
|
||||
crate::sync::config_sync(
|
||||
config.upstream_oauth2,
|
||||
config.clients,
|
||||
&mut mas_connection,
|
||||
&encrypter,
|
||||
&clock,
|
||||
// Don't prune — we don't want to be unnecessarily destructive
|
||||
false,
|
||||
// Not a dry run — we do want to create the providers in the database
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let Either::Left(mut mas_connection) = LockedMasDatabase::try_new(&mut mas_connection)
|
||||
.await
|
||||
.context("failed to issue query to lock database")?
|
||||
@@ -166,6 +201,19 @@ impl Options {
|
||||
Ok(ExitCode::SUCCESS)
|
||||
}
|
||||
Subcommand::Migrate => {
|
||||
let provider_id_mappings: HashMap<String, Uuid> = {
|
||||
let mas_oauth2 = UpstreamOAuth2Config::extract_or_default(figment)?;
|
||||
|
||||
mas_oauth2
|
||||
.providers
|
||||
.iter()
|
||||
.filter_map(|provider| {
|
||||
let synapse_idp_id = provider.synapse_idp_id.clone()?;
|
||||
Some((synapse_idp_id, Uuid::from(provider.id)))
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
|
||||
// TODO how should we handle warnings at this stage?
|
||||
|
||||
let mut reader = SynapseReader::new(&mut syn_conn, true).await?;
|
||||
@@ -181,8 +229,14 @@ impl Options {
|
||||
|
||||
// TODO progress reporting
|
||||
let mas_matrix = MatrixConfig::extract(figment)?;
|
||||
syn2mas::migrate(&mut reader, &mut writer, &mas_matrix.homeserver, &mut rng)
|
||||
.await?;
|
||||
syn2mas::migrate(
|
||||
&mut reader,
|
||||
&mut writer,
|
||||
&mas_matrix.homeserver,
|
||||
&mut rng,
|
||||
&provider_id_mappings,
|
||||
)
|
||||
.await?;
|
||||
|
||||
reader.finish().await?;
|
||||
writer.finish().await?;
|
||||
|
||||
18
crates/syn2mas/.sqlx/query-d79fd99ebed9033711f96113005096c848ae87c43b6430246ef3b6a1dc6a7a32.json
generated
Normal file
18
crates/syn2mas/.sqlx/query-d79fd99ebed9033711f96113005096c848ae87c43b6430246ef3b6a1dc6a7a32.json
generated
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO syn2mas__upstream_oauth_links\n (upstream_oauth_link_id, user_id, upstream_oauth_provider_id, subject, created_at)\n SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::UUID[], $4::TEXT[], $5::TIMESTAMP WITH TIME ZONE[])\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"UuidArray",
|
||||
"UuidArray",
|
||||
"UuidArray",
|
||||
"TextArray",
|
||||
"TimestamptzArray"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "d79fd99ebed9033711f96113005096c848ae87c43b6430246ef3b6a1dc6a7a32"
|
||||
}
|
||||
16
crates/syn2mas/src/mas_writer/fixtures/upstream_provider.sql
Normal file
16
crates/syn2mas/src/mas_writer/fixtures/upstream_provider.sql
Normal file
@@ -0,0 +1,16 @@
|
||||
INSERT INTO upstream_oauth_providers
|
||||
(
|
||||
upstream_oauth_provider_id,
|
||||
scope,
|
||||
client_id,
|
||||
token_endpoint_auth_method,
|
||||
created_at
|
||||
)
|
||||
VALUES
|
||||
(
|
||||
'00000000-0000-0000-0000-000000000004',
|
||||
'openid',
|
||||
'someClientId',
|
||||
'client_secret_basic',
|
||||
'2011-12-13 14:15:16Z'
|
||||
);
|
||||
@@ -10,7 +10,7 @@
|
||||
use std::fmt::Display;
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use futures_util::{future::BoxFuture, TryStreamExt};
|
||||
use futures_util::{future::BoxFuture, FutureExt, TryStreamExt};
|
||||
use sqlx::{query, query_as, Executor, PgConnection};
|
||||
use thiserror::Error;
|
||||
use thiserror_ext::{Construct, ContextInto};
|
||||
@@ -222,6 +222,14 @@ pub struct MasNewUnsupportedThreepid {
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
pub struct MasNewUpstreamOauthLink {
|
||||
pub link_id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub upstream_provider_id: Uuid,
|
||||
pub subject: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// The 'version' of the password hashing scheme used for passwords when they
|
||||
/// are migrated from Synapse to MAS.
|
||||
/// This is version 1, as in the previous syn2mas script.
|
||||
@@ -234,6 +242,7 @@ pub const MAS_TABLES_AFFECTED_BY_MIGRATION: &[&str] = &[
|
||||
"user_passwords",
|
||||
"user_emails",
|
||||
"user_unsupported_third_party_ids",
|
||||
"upstream_oauth_links",
|
||||
];
|
||||
|
||||
/// Detect whether a syn2mas migration has started on the given database.
|
||||
@@ -522,7 +531,7 @@ impl<'conn> MasWriter<'conn> {
|
||||
/// - If the database writer connection pool had an error.
|
||||
#[allow(clippy::missing_panics_doc)] // not a real panic
|
||||
#[tracing::instrument(skip_all, level = Level::DEBUG)]
|
||||
pub async fn write_users(&mut self, users: Vec<MasNewUser>) -> Result<(), Error> {
|
||||
pub fn write_users(&mut self, users: Vec<MasNewUser>) -> BoxFuture<'_, Result<(), Error>> {
|
||||
self.writer_pool.spawn_with_connection(move |conn| Box::pin(async move {
|
||||
// `UNNEST` is a fast way to do bulk inserts, as it lets us send multiple rows in one statement
|
||||
// without having to change the statement SQL thus altering the query plan.
|
||||
@@ -568,7 +577,7 @@ impl<'conn> MasWriter<'conn> {
|
||||
).execute(&mut *conn).await.into_database("writing users to MAS")?;
|
||||
|
||||
Ok(())
|
||||
})).await
|
||||
})).boxed()
|
||||
}
|
||||
|
||||
/// Write a batch of user passwords to the database.
|
||||
@@ -580,14 +589,10 @@ impl<'conn> MasWriter<'conn> {
|
||||
/// - If the database writer connection pool had an error.
|
||||
#[allow(clippy::missing_panics_doc)] // not a real panic
|
||||
#[tracing::instrument(skip_all, level = Level::DEBUG)]
|
||||
pub async fn write_passwords(
|
||||
pub fn write_passwords(
|
||||
&mut self,
|
||||
passwords: Vec<MasNewUserPassword>,
|
||||
) -> Result<(), Error> {
|
||||
if passwords.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
) -> BoxFuture<'_, Result<(), Error>> {
|
||||
self.writer_pool.spawn_with_connection(move |conn| Box::pin(async move {
|
||||
let mut user_password_ids: Vec<Uuid> = Vec::with_capacity(passwords.len());
|
||||
let mut user_ids: Vec<Uuid> = Vec::with_capacity(passwords.len());
|
||||
@@ -622,17 +627,14 @@ impl<'conn> MasWriter<'conn> {
|
||||
).execute(&mut *conn).await.into_database("writing users to MAS")?;
|
||||
|
||||
Ok(())
|
||||
})).await
|
||||
})).boxed()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, level = Level::DEBUG)]
|
||||
pub async fn write_email_threepids(
|
||||
pub fn write_email_threepids(
|
||||
&mut self,
|
||||
threepids: Vec<MasNewEmailThreepid>,
|
||||
) -> Result<(), Error> {
|
||||
if threepids.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
) -> BoxFuture<'_, Result<(), Error>> {
|
||||
self.writer_pool.spawn_with_connection(move |conn| {
|
||||
Box::pin(async move {
|
||||
let mut user_email_ids: Vec<Uuid> = Vec::with_capacity(threepids.len());
|
||||
@@ -669,17 +671,14 @@ impl<'conn> MasWriter<'conn> {
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}).await
|
||||
}).boxed()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, level = Level::DEBUG)]
|
||||
pub async fn write_unsupported_threepids(
|
||||
pub fn write_unsupported_threepids(
|
||||
&mut self,
|
||||
threepids: Vec<MasNewUnsupportedThreepid>,
|
||||
) -> Result<(), Error> {
|
||||
if threepids.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
) -> BoxFuture<'_, Result<(), Error>> {
|
||||
self.writer_pool.spawn_with_connection(move |conn| {
|
||||
Box::pin(async move {
|
||||
let mut user_ids: Vec<Uuid> = Vec::with_capacity(threepids.len());
|
||||
@@ -700,8 +699,6 @@ impl<'conn> MasWriter<'conn> {
|
||||
created_ats.push(created_at);
|
||||
}
|
||||
|
||||
// `confirmed_at` is going to get removed in a future MAS release,
|
||||
// so just populate with `created_at`
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO syn2mas__user_unsupported_third_party_ids
|
||||
@@ -716,7 +713,53 @@ impl<'conn> MasWriter<'conn> {
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}).await
|
||||
}).boxed()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, level = Level::DEBUG)]
|
||||
pub fn write_upstream_oauth_links(
|
||||
&mut self,
|
||||
links: Vec<MasNewUpstreamOauthLink>,
|
||||
) -> BoxFuture<'_, Result<(), Error>> {
|
||||
self.writer_pool.spawn_with_connection(move |conn| {
|
||||
Box::pin(async move {
|
||||
let mut link_ids: Vec<Uuid> = Vec::with_capacity(links.len());
|
||||
let mut user_ids: Vec<Uuid> = Vec::with_capacity(links.len());
|
||||
let mut upstream_provider_ids: Vec<Uuid> = Vec::with_capacity(links.len());
|
||||
let mut subjects: Vec<String> = Vec::with_capacity(links.len());
|
||||
let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(links.len());
|
||||
|
||||
for MasNewUpstreamOauthLink {
|
||||
link_id,
|
||||
user_id,
|
||||
upstream_provider_id,
|
||||
subject,
|
||||
created_at,
|
||||
} in links
|
||||
{
|
||||
link_ids.push(link_id);
|
||||
user_ids.push(user_id);
|
||||
upstream_provider_ids.push(upstream_provider_id);
|
||||
subjects.push(subject);
|
||||
created_ats.push(created_at);
|
||||
}
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO syn2mas__upstream_oauth_links
|
||||
(upstream_oauth_link_id, user_id, upstream_oauth_provider_id, subject, created_at)
|
||||
SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::UUID[], $4::TEXT[], $5::TIMESTAMP WITH TIME ZONE[])
|
||||
"#,
|
||||
&link_ids[..],
|
||||
&user_ids[..],
|
||||
&upstream_provider_ids[..],
|
||||
&subjects[..],
|
||||
&created_ats[..],
|
||||
).execute(&mut *conn).await.into_database("writing unsupported threepids to MAS")?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}).boxed()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -727,119 +770,60 @@ impl<'conn> MasWriter<'conn> {
|
||||
// stream to two tables at once...)
|
||||
const WRITE_BUFFER_BATCH_SIZE: usize = 4096;
|
||||
|
||||
pub struct MasUserWriteBuffer<'writer, 'conn> {
|
||||
users: Vec<MasNewUser>,
|
||||
passwords: Vec<MasNewUserPassword>,
|
||||
writer: &'writer mut MasWriter<'conn>,
|
||||
/// A function that can accept and flush buffers from a `MasWriteBuffer`.
|
||||
/// Intended uses are the methods on `MasWriter` such as `write_users`.
|
||||
type WriteBufferFlusher<'conn, T> =
|
||||
for<'a> fn(&'a mut MasWriter<'conn>, Vec<T>) -> BoxFuture<'a, Result<(), Error>>;
|
||||
|
||||
/// A buffer for writing rows to the MAS database.
|
||||
/// Generic over the type of rows.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if dropped before `finish()` has been called.
|
||||
pub struct MasWriteBuffer<'conn, T> {
|
||||
rows: Vec<T>,
|
||||
flusher: WriteBufferFlusher<'conn, T>,
|
||||
finished: bool,
|
||||
}
|
||||
|
||||
impl<'writer, 'conn> MasUserWriteBuffer<'writer, 'conn> {
|
||||
pub fn new(writer: &'writer mut MasWriter<'conn>) -> Self {
|
||||
MasUserWriteBuffer {
|
||||
users: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
|
||||
passwords: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
|
||||
writer,
|
||||
impl<'conn, T> MasWriteBuffer<'conn, T> {
|
||||
pub fn new(flusher: WriteBufferFlusher<'conn, T>) -> Self {
|
||||
MasWriteBuffer {
|
||||
rows: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
|
||||
flusher,
|
||||
finished: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn finish(mut self) -> Result<(), Error> {
|
||||
self.flush_users().await?;
|
||||
self.flush_passwords().await?;
|
||||
pub async fn finish(mut self, writer: &mut MasWriter<'conn>) -> Result<(), Error> {
|
||||
self.finished = true;
|
||||
self.flush(writer).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn flush_users(&mut self) -> Result<(), Error> {
|
||||
// via copy: 13s
|
||||
// not via copy: 14s
|
||||
// difference probably gets worse with latency
|
||||
self.writer
|
||||
.write_users(std::mem::take(&mut self.users))
|
||||
.await?;
|
||||
|
||||
self.users.reserve_exact(WRITE_BUFFER_BATCH_SIZE);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn flush_passwords(&mut self) -> Result<(), Error> {
|
||||
self.writer
|
||||
.write_passwords(std::mem::take(&mut self.passwords))
|
||||
.await?;
|
||||
self.passwords.reserve_exact(WRITE_BUFFER_BATCH_SIZE);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn write_user(&mut self, user: MasNewUser) -> Result<(), Error> {
|
||||
self.users.push(user);
|
||||
if self.users.len() >= WRITE_BUFFER_BATCH_SIZE {
|
||||
self.flush_users().await?;
|
||||
pub async fn flush(&mut self, writer: &mut MasWriter<'conn>) -> Result<(), Error> {
|
||||
if self.rows.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let rows = std::mem::take(&mut self.rows);
|
||||
self.rows.reserve_exact(WRITE_BUFFER_BATCH_SIZE);
|
||||
(self.flusher)(writer, rows).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn write_password(&mut self, password: MasNewUserPassword) -> Result<(), Error> {
|
||||
self.passwords.push(password);
|
||||
if self.passwords.len() >= WRITE_BUFFER_BATCH_SIZE {
|
||||
self.flush_passwords().await?;
|
||||
pub async fn write(&mut self, writer: &mut MasWriter<'conn>, row: T) -> Result<(), Error> {
|
||||
self.rows.push(row);
|
||||
if self.rows.len() >= WRITE_BUFFER_BATCH_SIZE {
|
||||
self.flush(writer).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MasThreepidWriteBuffer<'writer, 'conn> {
|
||||
email: Vec<MasNewEmailThreepid>,
|
||||
unsupported: Vec<MasNewUnsupportedThreepid>,
|
||||
writer: &'writer mut MasWriter<'conn>,
|
||||
}
|
||||
|
||||
impl<'writer, 'conn> MasThreepidWriteBuffer<'writer, 'conn> {
|
||||
pub fn new(writer: &'writer mut MasWriter<'conn>) -> Self {
|
||||
MasThreepidWriteBuffer {
|
||||
email: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
|
||||
unsupported: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
|
||||
writer,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn finish(mut self) -> Result<(), Error> {
|
||||
self.flush_emails().await?;
|
||||
self.flush_unsupported().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn flush_emails(&mut self) -> Result<(), Error> {
|
||||
self.writer
|
||||
.write_email_threepids(std::mem::take(&mut self.email))
|
||||
.await?;
|
||||
self.email.reserve_exact(WRITE_BUFFER_BATCH_SIZE);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn flush_unsupported(&mut self) -> Result<(), Error> {
|
||||
self.writer
|
||||
.write_unsupported_threepids(std::mem::take(&mut self.unsupported))
|
||||
.await?;
|
||||
self.unsupported.reserve_exact(WRITE_BUFFER_BATCH_SIZE);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn write_email(&mut self, user: MasNewEmailThreepid) -> Result<(), Error> {
|
||||
self.email.push(user);
|
||||
if self.email.len() >= WRITE_BUFFER_BATCH_SIZE {
|
||||
self.flush_emails().await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn write_password(
|
||||
&mut self,
|
||||
unsupported: MasNewUnsupportedThreepid,
|
||||
) -> Result<(), Error> {
|
||||
self.unsupported.push(unsupported);
|
||||
if self.unsupported.len() >= WRITE_BUFFER_BATCH_SIZE {
|
||||
self.flush_unsupported().await?;
|
||||
}
|
||||
Ok(())
|
||||
impl<T> Drop for MasWriteBuffer<'_, T> {
|
||||
fn drop(&mut self) {
|
||||
assert!(self.finished, "MasWriteBuffer dropped but not finished!");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -855,7 +839,8 @@ mod test {
|
||||
|
||||
use crate::{
|
||||
mas_writer::{
|
||||
MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUser, MasNewUserPassword,
|
||||
MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUpstreamOauthLink, MasNewUser,
|
||||
MasNewUserPassword,
|
||||
},
|
||||
LockedMasDatabase, MasWriter,
|
||||
};
|
||||
@@ -1085,4 +1070,39 @@ mod test {
|
||||
|
||||
assert_db_snapshot!(&mut conn);
|
||||
}
|
||||
|
||||
/// Tests writing a single user, with a link to an upstream provider.
|
||||
/// There needs to be an upstream provider in the database already — in the
|
||||
/// real migration, this is done by running a provider sync first.
|
||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR", fixtures("upstream_provider"))]
|
||||
async fn test_write_user_with_upstream_provider_link(pool: PgPool) {
|
||||
let mut conn = pool.acquire().await.unwrap();
|
||||
let mut writer = make_mas_writer(&pool, &mut conn).await;
|
||||
|
||||
writer
|
||||
.write_users(vec![MasNewUser {
|
||||
user_id: Uuid::from_u128(1u128),
|
||||
username: "alice".to_owned(),
|
||||
created_at: DateTime::default(),
|
||||
locked_at: None,
|
||||
can_request_admin: false,
|
||||
}])
|
||||
.await
|
||||
.expect("failed to write user");
|
||||
|
||||
writer
|
||||
.write_upstream_oauth_links(vec![MasNewUpstreamOauthLink {
|
||||
user_id: Uuid::from_u128(1u128),
|
||||
link_id: Uuid::from_u128(3u128),
|
||||
upstream_provider_id: Uuid::from_u128(4u128),
|
||||
subject: "12345.67890".to_owned(),
|
||||
created_at: DateTime::default(),
|
||||
}])
|
||||
.await
|
||||
.expect("failed to write link");
|
||||
|
||||
writer.finish().await.expect("failed to finish MasWriter");
|
||||
|
||||
assert_db_snapshot!(&mut conn);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
---
|
||||
source: crates/syn2mas/src/mas_writer/mod.rs
|
||||
expression: db_snapshot
|
||||
---
|
||||
upstream_oauth_links:
|
||||
- created_at: "1970-01-01 00:00:00+00"
|
||||
human_account_name: ~
|
||||
subject: "12345.67890"
|
||||
upstream_oauth_link_id: 00000000-0000-0000-0000-000000000003
|
||||
upstream_oauth_provider_id: 00000000-0000-0000-0000-000000000004
|
||||
user_id: 00000000-0000-0000-0000-000000000001
|
||||
upstream_oauth_providers:
|
||||
- additional_parameters: ~
|
||||
authorization_endpoint_override: ~
|
||||
brand_name: ~
|
||||
claims_imports: "{}"
|
||||
client_id: someClientId
|
||||
created_at: "2011-12-13 14:15:16+00"
|
||||
disabled_at: ~
|
||||
discovery_mode: oidc
|
||||
encrypted_client_secret: ~
|
||||
fetch_userinfo: "false"
|
||||
human_name: ~
|
||||
id_token_signed_response_alg: RS256
|
||||
issuer: ~
|
||||
jwks_uri_override: ~
|
||||
pkce_mode: auto
|
||||
response_mode: query
|
||||
scope: openid
|
||||
token_endpoint_auth_method: client_secret_basic
|
||||
token_endpoint_override: ~
|
||||
token_endpoint_signing_alg: ~
|
||||
upstream_oauth_provider_id: 00000000-0000-0000-0000-000000000004
|
||||
userinfo_endpoint_override: ~
|
||||
userinfo_signed_response_alg: ~
|
||||
users:
|
||||
- can_request_admin: "false"
|
||||
created_at: "1970-01-01 00:00:00+00"
|
||||
locked_at: ~
|
||||
primary_user_email_id: ~
|
||||
user_id: 00000000-0000-0000-0000-000000000001
|
||||
username: alice
|
||||
@@ -12,3 +12,4 @@ ALTER TABLE syn2mas__users RENAME TO users;
|
||||
ALTER TABLE syn2mas__user_passwords RENAME TO user_passwords;
|
||||
ALTER TABLE syn2mas__user_emails RENAME TO user_emails;
|
||||
ALTER TABLE syn2mas__user_unsupported_third_party_ids RENAME TO user_unsupported_third_party_ids;
|
||||
ALTER TABLE syn2mas__upstream_oauth_links RENAME TO upstream_oauth_links;
|
||||
|
||||
@@ -41,3 +41,4 @@ ALTER TABLE users RENAME TO syn2mas__users;
|
||||
ALTER TABLE user_passwords RENAME TO syn2mas__user_passwords;
|
||||
ALTER TABLE user_emails RENAME TO syn2mas__user_emails;
|
||||
ALTER TABLE user_unsupported_third_party_ids RENAME TO syn2mas__user_unsupported_third_party_ids;
|
||||
ALTER TABLE upstream_oauth_links RENAME TO syn2mas__upstream_oauth_links;
|
||||
|
||||
@@ -25,10 +25,12 @@ use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
mas_writer::{
|
||||
self, MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUser, MasNewUserPassword,
|
||||
MasThreepidWriteBuffer, MasUserWriteBuffer, MasWriter,
|
||||
self, MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUpstreamOauthLink, MasNewUser,
|
||||
MasNewUserPassword, MasWriteBuffer, MasWriter,
|
||||
},
|
||||
synapse_reader::{
|
||||
self, ExtractLocalpartError, FullUserId, SynapseExternalId, SynapseThreepid, SynapseUser,
|
||||
},
|
||||
synapse_reader::{self, ExtractLocalpartError, FullUserId, SynapseThreepid, SynapseUser},
|
||||
SynapseReader,
|
||||
};
|
||||
|
||||
@@ -49,6 +51,16 @@ pub enum Error {
|
||||
source: ExtractLocalpartError,
|
||||
user: FullUserId,
|
||||
},
|
||||
#[error("user {user} was not found for migration but a row in {table} was found for them")]
|
||||
MissingUserFromDependentTable { table: String, user: FullUserId },
|
||||
#[error("missing a mapping for the auth provider with ID {synapse_id:?} (used by {user} and maybe other users)")]
|
||||
MissingAuthProviderMapping {
|
||||
/// `auth_provider` ID of the provider in Synapse, for which we have no
|
||||
/// mapping
|
||||
synapse_id: String,
|
||||
/// a user that is using this auth provider
|
||||
user: FullUserId,
|
||||
},
|
||||
}
|
||||
|
||||
struct UsersMigrated {
|
||||
@@ -68,11 +80,13 @@ struct UsersMigrated {
|
||||
///
|
||||
/// - An underlying database access error, either to MAS or to Synapse.
|
||||
/// - Invalid data in the Synapse database.
|
||||
#[allow(clippy::implicit_hasher)]
|
||||
pub async fn migrate(
|
||||
synapse: &mut SynapseReader<'_>,
|
||||
mas: &mut MasWriter<'_>,
|
||||
server_name: &str,
|
||||
rng: &mut impl RngCore,
|
||||
provider_id_mapping: &HashMap<String, Uuid>,
|
||||
) -> Result<(), Error> {
|
||||
let counts = synapse.count_rows().await.into_synapse("counting users")?;
|
||||
|
||||
@@ -97,6 +111,16 @@ pub async fn migrate(
|
||||
)
|
||||
.await?;
|
||||
|
||||
migrate_external_ids(
|
||||
synapse,
|
||||
mas,
|
||||
server_name,
|
||||
rng,
|
||||
&migrated_users.user_localparts_to_uuid,
|
||||
provider_id_mapping,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -108,7 +132,8 @@ async fn migrate_users(
|
||||
server_name: &str,
|
||||
rng: &mut impl RngCore,
|
||||
) -> Result<UsersMigrated, Error> {
|
||||
let mut write_buffer = MasUserWriteBuffer::new(mas);
|
||||
let mut user_buffer = MasWriteBuffer::new(MasWriter::write_users);
|
||||
let mut password_buffer = MasWriteBuffer::new(MasWriter::write_passwords);
|
||||
let mut users_stream = pin!(synapse.read_users());
|
||||
// TODO is 1:1 capacity enough for a hashmap?
|
||||
let mut user_localparts_to_uuid = HashMap::with_capacity(user_count_hint);
|
||||
@@ -119,23 +144,24 @@ async fn migrate_users(
|
||||
|
||||
user_localparts_to_uuid.insert(CompactString::new(&mas_user.username), mas_user.user_id);
|
||||
|
||||
write_buffer
|
||||
.write_user(mas_user)
|
||||
user_buffer
|
||||
.write(mas, mas_user)
|
||||
.await
|
||||
.into_mas("writing user")?;
|
||||
|
||||
if let Some(mas_password) = mas_password_opt {
|
||||
write_buffer
|
||||
.write_password(mas_password)
|
||||
password_buffer
|
||||
.write(mas, mas_password)
|
||||
.await
|
||||
.into_mas("writing password")?;
|
||||
}
|
||||
}
|
||||
|
||||
write_buffer
|
||||
.finish()
|
||||
user_buffer.finish(mas).await.into_mas("writing users")?;
|
||||
password_buffer
|
||||
.finish(mas)
|
||||
.await
|
||||
.into_mas("writing users & passwords")?;
|
||||
.into_mas("writing passwords")?;
|
||||
|
||||
Ok(UsersMigrated {
|
||||
user_localparts_to_uuid,
|
||||
@@ -150,7 +176,8 @@ async fn migrate_threepids(
|
||||
rng: &mut impl RngCore,
|
||||
user_localparts_to_uuid: &HashMap<CompactString, Uuid>,
|
||||
) -> Result<(), Error> {
|
||||
let mut write_buffer = MasThreepidWriteBuffer::new(mas);
|
||||
let mut email_buffer = MasWriteBuffer::new(MasWriter::write_email_threepids);
|
||||
let mut unsupported_buffer = MasWriteBuffer::new(MasWriter::write_unsupported_threepids);
|
||||
let mut users_stream = pin!(synapse.read_threepids());
|
||||
|
||||
while let Some(threepid_res) = users_stream.next().await {
|
||||
@@ -167,36 +194,121 @@ async fn migrate_threepids(
|
||||
.into_extract_localpart(synapse_user_id.clone())?
|
||||
.to_owned();
|
||||
let Some(user_id) = user_localparts_to_uuid.get(username.as_str()).copied() else {
|
||||
todo!()
|
||||
return Err(Error::MissingUserFromDependentTable {
|
||||
table: "user_threepids".to_owned(),
|
||||
user: synapse_user_id,
|
||||
});
|
||||
};
|
||||
|
||||
if medium == "email" {
|
||||
write_buffer
|
||||
.write_email(MasNewEmailThreepid {
|
||||
user_id,
|
||||
user_email_id: Uuid::from(Ulid::from_datetime_with_source(
|
||||
created_at.into(),
|
||||
rng,
|
||||
)),
|
||||
email: address,
|
||||
created_at,
|
||||
})
|
||||
email_buffer
|
||||
.write(
|
||||
mas,
|
||||
MasNewEmailThreepid {
|
||||
user_id,
|
||||
user_email_id: Uuid::from(Ulid::from_datetime_with_source(
|
||||
created_at.into(),
|
||||
rng,
|
||||
)),
|
||||
email: address,
|
||||
created_at,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.into_mas("writing email")?;
|
||||
} else {
|
||||
write_buffer
|
||||
.write_password(MasNewUnsupportedThreepid {
|
||||
user_id,
|
||||
medium,
|
||||
address,
|
||||
created_at,
|
||||
})
|
||||
unsupported_buffer
|
||||
.write(
|
||||
mas,
|
||||
MasNewUnsupportedThreepid {
|
||||
user_id,
|
||||
medium,
|
||||
address,
|
||||
created_at,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.into_mas("writing unsupported threepid")?;
|
||||
}
|
||||
}
|
||||
|
||||
write_buffer.finish().await.into_mas("writing threepids")?;
|
||||
email_buffer
|
||||
.finish(mas)
|
||||
.await
|
||||
.into_mas("writing email threepids")?;
|
||||
unsupported_buffer
|
||||
.finish(mas)
|
||||
.await
|
||||
.into_mas("writing unsupported threepids")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `provider_id_mapping`: mapping from Synapse `auth_provider` ID to UUID of
|
||||
/// the upstream provider in MAS.
|
||||
#[tracing::instrument(skip_all, level = Level::INFO)]
|
||||
async fn migrate_external_ids(
|
||||
synapse: &mut SynapseReader<'_>,
|
||||
mas: &mut MasWriter<'_>,
|
||||
server_name: &str,
|
||||
rng: &mut impl RngCore,
|
||||
user_localparts_to_uuid: &HashMap<CompactString, Uuid>,
|
||||
provider_id_mapping: &HashMap<String, Uuid>,
|
||||
) -> Result<(), Error> {
|
||||
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_upstream_oauth_links);
|
||||
let mut extids_stream = pin!(synapse.read_user_external_ids());
|
||||
|
||||
while let Some(extid_res) = extids_stream.next().await {
|
||||
let SynapseExternalId {
|
||||
user_id: synapse_user_id,
|
||||
auth_provider,
|
||||
external_id: subject,
|
||||
} = extid_res.into_synapse("reading external ID")?;
|
||||
let username = synapse_user_id
|
||||
.extract_localpart(server_name)
|
||||
.into_extract_localpart(synapse_user_id.clone())?
|
||||
.to_owned();
|
||||
let Some(user_id) = user_localparts_to_uuid.get(username.as_str()).copied() else {
|
||||
return Err(Error::MissingUserFromDependentTable {
|
||||
table: "user_external_ids".to_owned(),
|
||||
user: synapse_user_id,
|
||||
});
|
||||
};
|
||||
|
||||
let Some(&upstream_provider_id) = provider_id_mapping.get(&auth_provider) else {
|
||||
return Err(Error::MissingAuthProviderMapping {
|
||||
synapse_id: auth_provider,
|
||||
user: synapse_user_id,
|
||||
});
|
||||
};
|
||||
|
||||
// To save having to store user creation times, extract it from the ULID
|
||||
// This gives millisecond precision — good enough.
|
||||
let user_created_ts = Ulid::from(user_id).datetime();
|
||||
|
||||
let link_id: Uuid = Ulid::from_datetime_with_source(user_created_ts, rng).into();
|
||||
|
||||
write_buffer
|
||||
.write(
|
||||
mas,
|
||||
MasNewUpstreamOauthLink {
|
||||
link_id,
|
||||
user_id,
|
||||
upstream_provider_id,
|
||||
subject,
|
||||
created_at: user_created_ts.into(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.into_mas("failed to write upstream link")?;
|
||||
}
|
||||
|
||||
write_buffer
|
||||
.finish(mas)
|
||||
.await
|
||||
.into_mas("writing threepids")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
INSERT INTO user_external_ids
|
||||
(
|
||||
user_id,
|
||||
auth_provider,
|
||||
external_id
|
||||
)
|
||||
VALUES
|
||||
(
|
||||
'@alice:example.com',
|
||||
'oidc-raasu',
|
||||
'871.syn30'
|
||||
);
|
||||
@@ -8,6 +8,8 @@
|
||||
//! This module provides facilities for streaming relevant types of database
|
||||
//! records from a Synapse database.
|
||||
|
||||
use std::fmt::Display;
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use futures_util::{Stream, TryStreamExt};
|
||||
use sqlx::{query, Acquire, FromRow, PgConnection, Postgres, Row, Transaction, Type};
|
||||
@@ -30,6 +32,12 @@ pub enum Error {
|
||||
#[derive(Clone, Debug, sqlx::Decode, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub struct FullUserId(pub String);
|
||||
|
||||
impl Display for FullUserId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<Postgres> for FullUserId {
|
||||
fn type_info() -> <sqlx::Postgres as sqlx::Database>::TypeInfo {
|
||||
<String as Type<Postgres>>::type_info()
|
||||
@@ -193,13 +201,21 @@ pub struct SynapseThreepid {
|
||||
pub added_at: MillisecondsTimestamp,
|
||||
}
|
||||
|
||||
/// Row of the `user_external_ids` table in Synapse.
|
||||
#[derive(Clone, Debug, FromRow, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub struct SynapseExternalId {
|
||||
pub user_id: FullUserId,
|
||||
pub auth_provider: String,
|
||||
pub external_id: String,
|
||||
}
|
||||
|
||||
/// List of Synapse tables that we should acquire an `EXCLUSIVE` lock on.
|
||||
///
|
||||
/// This is a safety measure against other processes changing the data
|
||||
/// underneath our feet. It's still not a good idea to run Synapse at the same
|
||||
/// time as the migration.
|
||||
// TODO not complete!
|
||||
const TABLES_TO_LOCK: &[&str] = &["users"];
|
||||
const TABLES_TO_LOCK: &[&str] = &["users", "user_threepids", "user_external_ids"];
|
||||
|
||||
/// Number of migratable rows in various Synapse tables.
|
||||
/// Used to estimate progress.
|
||||
@@ -319,6 +335,21 @@ impl<'conn> SynapseReader<'conn> {
|
||||
.fetch(&mut *self.txn)
|
||||
.map_err(|err| err.into_database("reading Synapse threepids"))
|
||||
}
|
||||
|
||||
/// Read associations between Synapse users and external identity providers
|
||||
pub fn read_user_external_ids(
|
||||
&mut self,
|
||||
) -> impl Stream<Item = Result<SynapseExternalId, Error>> + '_ {
|
||||
sqlx::query_as(
|
||||
"
|
||||
SELECT
|
||||
user_id, auth_provider, external_id
|
||||
FROM user_external_ids
|
||||
",
|
||||
)
|
||||
.fetch(&mut *self.txn)
|
||||
.map_err(|err| err.into_database("reading Synapse user external IDs"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -330,7 +361,7 @@ mod test {
|
||||
use sqlx::{migrate::Migrator, PgPool};
|
||||
|
||||
use crate::{
|
||||
synapse_reader::{SynapseThreepid, SynapseUser},
|
||||
synapse_reader::{SynapseExternalId, SynapseThreepid, SynapseUser},
|
||||
SynapseReader,
|
||||
};
|
||||
|
||||
@@ -368,4 +399,20 @@ mod test {
|
||||
|
||||
assert_debug_snapshot!(threepids);
|
||||
}
|
||||
|
||||
#[sqlx::test(migrator = "MIGRATOR", fixtures("user_alice", "external_ids_alice"))]
|
||||
async fn test_read_external_ids(pool: PgPool) {
|
||||
let mut conn = pool.acquire().await.expect("failed to get connection");
|
||||
let mut reader = SynapseReader::new(&mut conn, false)
|
||||
.await
|
||||
.expect("failed to make SynapseReader");
|
||||
|
||||
let external_ids: BTreeSet<SynapseExternalId> = reader
|
||||
.read_user_external_ids()
|
||||
.try_collect()
|
||||
.await
|
||||
.expect("failed to read Synapse external user IDs");
|
||||
|
||||
assert_debug_snapshot!(external_ids);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
---
|
||||
source: crates/syn2mas/src/synapse_reader/mod.rs
|
||||
expression: external_ids
|
||||
---
|
||||
{
|
||||
SynapseExternalId {
|
||||
user_id: FullUserId(
|
||||
"@alice:example.com",
|
||||
),
|
||||
auth_provider: "oidc-raasu",
|
||||
external_id: "871.syn30",
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
-- Copyright 2025 New Vector Ltd.
|
||||
--
|
||||
-- SPDX-License-Identifier: AGPL-3.0-only
|
||||
-- Please see LICENSE in the repository root for full details.
|
||||
|
||||
-- Brings in the `user_external_ids` table from Synapse
|
||||
|
||||
CREATE TABLE user_external_ids (
|
||||
auth_provider text NOT NULL,
|
||||
external_id text NOT NULL,
|
||||
user_id text NOT NULL
|
||||
);
|
||||
Reference in New Issue
Block a user