syn2mas: Support migrating external IDs as upstream OAuth2 providers (#3917)

* Add `SynapseReader` support and test for external IDs

* Run database migrations and do a config sync before syn2mas

* FullUserId: implement Display

* Add `MasWriter` support and test for upstream OAuth provider links

* Remove special-purpose write buffers and use only the generic one

* Build the provider ID mapping
This commit is contained in:
reivilibre
2025-01-30 10:34:20 +00:00
committed by GitHub
parent de597da468
commit fec4efd9d8
12 changed files with 508 additions and 160 deletions

View File

@@ -1,14 +1,19 @@
use std::process::ExitCode;
use std::{collections::HashMap, process::ExitCode};
use anyhow::Context;
use camino::Utf8PathBuf;
use clap::Parser;
use figment::Figment;
use mas_config::{ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig};
use mas_config::{
ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig, SyncConfig,
UpstreamOAuth2Config,
};
use mas_storage::SystemClock;
use mas_storage_pg::MIGRATOR;
use rand::thread_rng;
use sqlx::{postgres::PgConnectOptions, Connection, Either, PgConnection};
use sqlx::{postgres::PgConnectOptions, types::Uuid, Connection, Either, PgConnection};
use syn2mas::{synapse_config, LockedMasDatabase, MasWriter, SynapseReader};
use tracing::{error, warn};
use tracing::{error, info_span, warn, Instrument};
use crate::util::database_connection_from_config;
@@ -75,6 +80,7 @@ enum Subcommand {
const NUM_WRITER_CONNECTIONS: usize = 8;
impl Options {
#[allow(clippy::too_many_lines)]
pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
warn!("This version of the syn2mas tool is EXPERIMENTAL and INCOMPLETE. Do not use it, except for TESTING.");
if !self.experimental_accepted {
@@ -107,6 +113,35 @@ impl Options {
let mut mas_connection = database_connection_from_config(&config).await?;
MIGRATOR
.run(&mut mas_connection)
.instrument(info_span!("db.migrate"))
.await
.context("could not run migrations")?;
if matches!(&self.subcommand, Subcommand::Migrate { .. }) {
// First perform a config sync
// This is crucial to ensure we register upstream OAuth providers
// in the MAS database
//
let config = SyncConfig::extract(figment)?;
let clock = SystemClock::default();
let encrypter = config.secrets.encrypter();
crate::sync::config_sync(
config.upstream_oauth2,
config.clients,
&mut mas_connection,
&encrypter,
&clock,
// Don't prune — we don't want to be unnecessarily destructive
false,
// Not a dry run — we do want to create the providers in the database
false,
)
.await?;
}
let Either::Left(mut mas_connection) = LockedMasDatabase::try_new(&mut mas_connection)
.await
.context("failed to issue query to lock database")?
@@ -166,6 +201,19 @@ impl Options {
Ok(ExitCode::SUCCESS)
}
Subcommand::Migrate => {
let provider_id_mappings: HashMap<String, Uuid> = {
let mas_oauth2 = UpstreamOAuth2Config::extract_or_default(figment)?;
mas_oauth2
.providers
.iter()
.filter_map(|provider| {
let synapse_idp_id = provider.synapse_idp_id.clone()?;
Some((synapse_idp_id, Uuid::from(provider.id)))
})
.collect()
};
// TODO how should we handle warnings at this stage?
let mut reader = SynapseReader::new(&mut syn_conn, true).await?;
@@ -181,8 +229,14 @@ impl Options {
// TODO progress reporting
let mas_matrix = MatrixConfig::extract(figment)?;
syn2mas::migrate(&mut reader, &mut writer, &mas_matrix.homeserver, &mut rng)
.await?;
syn2mas::migrate(
&mut reader,
&mut writer,
&mas_matrix.homeserver,
&mut rng,
&provider_id_mappings,
)
.await?;
reader.finish().await?;
writer.finish().await?;

View File

@@ -0,0 +1,18 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO syn2mas__upstream_oauth_links\n (upstream_oauth_link_id, user_id, upstream_oauth_provider_id, subject, created_at)\n SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::UUID[], $4::TEXT[], $5::TIMESTAMP WITH TIME ZONE[])\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"UuidArray",
"UuidArray",
"UuidArray",
"TextArray",
"TimestamptzArray"
]
},
"nullable": []
},
"hash": "d79fd99ebed9033711f96113005096c848ae87c43b6430246ef3b6a1dc6a7a32"
}

View File

@@ -0,0 +1,16 @@
INSERT INTO upstream_oauth_providers
(
upstream_oauth_provider_id,
scope,
client_id,
token_endpoint_auth_method,
created_at
)
VALUES
(
'00000000-0000-0000-0000-000000000004',
'openid',
'someClientId',
'client_secret_basic',
'2011-12-13 14:15:16Z'
);

View File

@@ -10,7 +10,7 @@
use std::fmt::Display;
use chrono::{DateTime, Utc};
use futures_util::{future::BoxFuture, TryStreamExt};
use futures_util::{future::BoxFuture, FutureExt, TryStreamExt};
use sqlx::{query, query_as, Executor, PgConnection};
use thiserror::Error;
use thiserror_ext::{Construct, ContextInto};
@@ -222,6 +222,14 @@ pub struct MasNewUnsupportedThreepid {
pub created_at: DateTime<Utc>,
}
/// A new row for the MAS `upstream_oauth_links` table: an association between
/// a migrated user and an upstream OAuth 2.0 provider.
pub struct MasNewUpstreamOauthLink {
/// Unique ID for this link row.
pub link_id: Uuid,
/// ID of the MAS user this link belongs to.
pub user_id: Uuid,
/// ID of the upstream OAuth 2.0 provider registered in MAS.
pub upstream_provider_id: Uuid,
/// Subject of the link: the user's identifier at the upstream provider
/// (taken from Synapse's `external_id` column during migration).
pub subject: String,
/// When the link was created.
pub created_at: DateTime<Utc>,
}
/// The 'version' of the password hashing scheme used for passwords when they
/// are migrated from Synapse to MAS.
/// This is version 1, as in the previous syn2mas script.
@@ -234,6 +242,7 @@ pub const MAS_TABLES_AFFECTED_BY_MIGRATION: &[&str] = &[
"user_passwords",
"user_emails",
"user_unsupported_third_party_ids",
"upstream_oauth_links",
];
/// Detect whether a syn2mas migration has started on the given database.
@@ -522,7 +531,7 @@ impl<'conn> MasWriter<'conn> {
/// - If the database writer connection pool had an error.
#[allow(clippy::missing_panics_doc)] // not a real panic
#[tracing::instrument(skip_all, level = Level::DEBUG)]
pub async fn write_users(&mut self, users: Vec<MasNewUser>) -> Result<(), Error> {
pub fn write_users(&mut self, users: Vec<MasNewUser>) -> BoxFuture<'_, Result<(), Error>> {
self.writer_pool.spawn_with_connection(move |conn| Box::pin(async move {
// `UNNEST` is a fast way to do bulk inserts, as it lets us send multiple rows in one statement
// without having to change the statement SQL thus altering the query plan.
@@ -568,7 +577,7 @@ impl<'conn> MasWriter<'conn> {
).execute(&mut *conn).await.into_database("writing users to MAS")?;
Ok(())
})).await
})).boxed()
}
/// Write a batch of user passwords to the database.
@@ -580,14 +589,10 @@ impl<'conn> MasWriter<'conn> {
/// - If the database writer connection pool had an error.
#[allow(clippy::missing_panics_doc)] // not a real panic
#[tracing::instrument(skip_all, level = Level::DEBUG)]
pub async fn write_passwords(
pub fn write_passwords(
&mut self,
passwords: Vec<MasNewUserPassword>,
) -> Result<(), Error> {
if passwords.is_empty() {
return Ok(());
}
) -> BoxFuture<'_, Result<(), Error>> {
self.writer_pool.spawn_with_connection(move |conn| Box::pin(async move {
let mut user_password_ids: Vec<Uuid> = Vec::with_capacity(passwords.len());
let mut user_ids: Vec<Uuid> = Vec::with_capacity(passwords.len());
@@ -622,17 +627,14 @@ impl<'conn> MasWriter<'conn> {
).execute(&mut *conn).await.into_database("writing users to MAS")?;
Ok(())
})).await
})).boxed()
}
#[tracing::instrument(skip_all, level = Level::DEBUG)]
pub async fn write_email_threepids(
pub fn write_email_threepids(
&mut self,
threepids: Vec<MasNewEmailThreepid>,
) -> Result<(), Error> {
if threepids.is_empty() {
return Ok(());
}
) -> BoxFuture<'_, Result<(), Error>> {
self.writer_pool.spawn_with_connection(move |conn| {
Box::pin(async move {
let mut user_email_ids: Vec<Uuid> = Vec::with_capacity(threepids.len());
@@ -669,17 +671,14 @@ impl<'conn> MasWriter<'conn> {
Ok(())
})
}).await
}).boxed()
}
#[tracing::instrument(skip_all, level = Level::DEBUG)]
pub async fn write_unsupported_threepids(
pub fn write_unsupported_threepids(
&mut self,
threepids: Vec<MasNewUnsupportedThreepid>,
) -> Result<(), Error> {
if threepids.is_empty() {
return Ok(());
}
) -> BoxFuture<'_, Result<(), Error>> {
self.writer_pool.spawn_with_connection(move |conn| {
Box::pin(async move {
let mut user_ids: Vec<Uuid> = Vec::with_capacity(threepids.len());
@@ -700,8 +699,6 @@ impl<'conn> MasWriter<'conn> {
created_ats.push(created_at);
}
// `confirmed_at` is going to get removed in a future MAS release,
// so just populate with `created_at`
sqlx::query!(
r#"
INSERT INTO syn2mas__user_unsupported_third_party_ids
@@ -716,7 +713,53 @@ impl<'conn> MasWriter<'conn> {
Ok(())
})
}).await
}).boxed()
}
/// Write a batch of upstream OAuth 2.0 links to the database.
///
/// # Errors
///
/// Errors are returned under the following circumstances:
///
/// - If the database writer connection pool had an error.
#[tracing::instrument(skip_all, level = Level::DEBUG)]
pub fn write_upstream_oauth_links(
    &mut self,
    links: Vec<MasNewUpstreamOauthLink>,
) -> BoxFuture<'_, Result<(), Error>> {
    self.writer_pool.spawn_with_connection(move |conn| {
        Box::pin(async move {
            // Decompose the rows into one column-major Vec per column so we
            // can use `UNNEST` for a fast multi-row insert.
            let mut link_ids: Vec<Uuid> = Vec::with_capacity(links.len());
            let mut user_ids: Vec<Uuid> = Vec::with_capacity(links.len());
            let mut upstream_provider_ids: Vec<Uuid> = Vec::with_capacity(links.len());
            let mut subjects: Vec<String> = Vec::with_capacity(links.len());
            let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(links.len());
            for MasNewUpstreamOauthLink {
                link_id,
                user_id,
                upstream_provider_id,
                subject,
                created_at,
            } in links
            {
                link_ids.push(link_id);
                user_ids.push(user_id);
                upstream_provider_ids.push(upstream_provider_id);
                subjects.push(subject);
                created_ats.push(created_at);
            }

            sqlx::query!(
                r#"
                INSERT INTO syn2mas__upstream_oauth_links
                (upstream_oauth_link_id, user_id, upstream_oauth_provider_id, subject, created_at)
                SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::UUID[], $4::TEXT[], $5::TIMESTAMP WITH TIME ZONE[])
                "#,
                &link_ids[..],
                &user_ids[..],
                &upstream_provider_ids[..],
                &subjects[..],
                &created_ats[..],
            )
            .execute(&mut *conn)
            .await
            // was: "writing unsupported threepids to MAS" — a copy-paste slip
            // from the threepid writer; this method writes OAuth links.
            .into_database("writing upstream OAuth links to MAS")?;
            Ok(())
        })
    }).boxed()
}
}
@@ -727,119 +770,60 @@ impl<'conn> MasWriter<'conn> {
// stream to two tables at once...)
const WRITE_BUFFER_BATCH_SIZE: usize = 4096;
pub struct MasUserWriteBuffer<'writer, 'conn> {
users: Vec<MasNewUser>,
passwords: Vec<MasNewUserPassword>,
writer: &'writer mut MasWriter<'conn>,
/// A function that can accept and flush buffers from a `MasWriteBuffer`.
/// Intended uses are the methods on `MasWriter` such as `write_users`.
type WriteBufferFlusher<'conn, T> =
for<'a> fn(&'a mut MasWriter<'conn>, Vec<T>) -> BoxFuture<'a, Result<(), Error>>;
/// A buffer for writing rows to the MAS database.
/// Generic over the type of rows.
///
/// # Panics
///
/// Panics if dropped before `finish()` has been called.
pub struct MasWriteBuffer<'conn, T> {
rows: Vec<T>,
flusher: WriteBufferFlusher<'conn, T>,
finished: bool,
}
impl<'writer, 'conn> MasUserWriteBuffer<'writer, 'conn> {
pub fn new(writer: &'writer mut MasWriter<'conn>) -> Self {
MasUserWriteBuffer {
users: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
passwords: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
writer,
impl<'conn, T> MasWriteBuffer<'conn, T> {
pub fn new(flusher: WriteBufferFlusher<'conn, T>) -> Self {
MasWriteBuffer {
rows: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
flusher,
finished: false,
}
}
pub async fn finish(mut self) -> Result<(), Error> {
self.flush_users().await?;
self.flush_passwords().await?;
pub async fn finish(mut self, writer: &mut MasWriter<'conn>) -> Result<(), Error> {
self.finished = true;
self.flush(writer).await?;
Ok(())
}
pub async fn flush_users(&mut self) -> Result<(), Error> {
// via copy: 13s
// not via copy: 14s
// difference probably gets worse with latency
self.writer
.write_users(std::mem::take(&mut self.users))
.await?;
self.users.reserve_exact(WRITE_BUFFER_BATCH_SIZE);
Ok(())
}
pub async fn flush_passwords(&mut self) -> Result<(), Error> {
self.writer
.write_passwords(std::mem::take(&mut self.passwords))
.await?;
self.passwords.reserve_exact(WRITE_BUFFER_BATCH_SIZE);
Ok(())
}
pub async fn write_user(&mut self, user: MasNewUser) -> Result<(), Error> {
self.users.push(user);
if self.users.len() >= WRITE_BUFFER_BATCH_SIZE {
self.flush_users().await?;
pub async fn flush(&mut self, writer: &mut MasWriter<'conn>) -> Result<(), Error> {
if self.rows.is_empty() {
return Ok(());
}
let rows = std::mem::take(&mut self.rows);
self.rows.reserve_exact(WRITE_BUFFER_BATCH_SIZE);
(self.flusher)(writer, rows).await?;
Ok(())
}
pub async fn write_password(&mut self, password: MasNewUserPassword) -> Result<(), Error> {
self.passwords.push(password);
if self.passwords.len() >= WRITE_BUFFER_BATCH_SIZE {
self.flush_passwords().await?;
pub async fn write(&mut self, writer: &mut MasWriter<'conn>, row: T) -> Result<(), Error> {
self.rows.push(row);
if self.rows.len() >= WRITE_BUFFER_BATCH_SIZE {
self.flush(writer).await?;
}
Ok(())
}
}
pub struct MasThreepidWriteBuffer<'writer, 'conn> {
email: Vec<MasNewEmailThreepid>,
unsupported: Vec<MasNewUnsupportedThreepid>,
writer: &'writer mut MasWriter<'conn>,
}
impl<'writer, 'conn> MasThreepidWriteBuffer<'writer, 'conn> {
pub fn new(writer: &'writer mut MasWriter<'conn>) -> Self {
MasThreepidWriteBuffer {
email: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
unsupported: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
writer,
}
}
pub async fn finish(mut self) -> Result<(), Error> {
self.flush_emails().await?;
self.flush_unsupported().await?;
Ok(())
}
pub async fn flush_emails(&mut self) -> Result<(), Error> {
self.writer
.write_email_threepids(std::mem::take(&mut self.email))
.await?;
self.email.reserve_exact(WRITE_BUFFER_BATCH_SIZE);
Ok(())
}
pub async fn flush_unsupported(&mut self) -> Result<(), Error> {
self.writer
.write_unsupported_threepids(std::mem::take(&mut self.unsupported))
.await?;
self.unsupported.reserve_exact(WRITE_BUFFER_BATCH_SIZE);
Ok(())
}
pub async fn write_email(&mut self, user: MasNewEmailThreepid) -> Result<(), Error> {
self.email.push(user);
if self.email.len() >= WRITE_BUFFER_BATCH_SIZE {
self.flush_emails().await?;
}
Ok(())
}
pub async fn write_password(
&mut self,
unsupported: MasNewUnsupportedThreepid,
) -> Result<(), Error> {
self.unsupported.push(unsupported);
if self.unsupported.len() >= WRITE_BUFFER_BATCH_SIZE {
self.flush_unsupported().await?;
}
Ok(())
impl<T> Drop for MasWriteBuffer<'_, T> {
    fn drop(&mut self) {
        // Dropping a buffer without calling `finish()` could silently discard
        // buffered rows, so treat it as a programming error and panic.
        if !self.finished {
            panic!("MasWriteBuffer dropped but not finished!");
        }
    }
}
@@ -855,7 +839,8 @@ mod test {
use crate::{
mas_writer::{
MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUser, MasNewUserPassword,
MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUpstreamOauthLink, MasNewUser,
MasNewUserPassword,
},
LockedMasDatabase, MasWriter,
};
@@ -1085,4 +1070,39 @@ mod test {
assert_db_snapshot!(&mut conn);
}
/// Tests writing a single user, with a link to an upstream provider.
/// There needs to be an upstream provider in the database already — in the
/// real migration, this is done by running a provider sync first.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR", fixtures("upstream_provider"))]
async fn test_write_user_with_upstream_provider_link(pool: PgPool) {
let mut conn = pool.acquire().await.unwrap();
let mut writer = make_mas_writer(&pool, &mut conn).await;
// Write the user first, so the link's `user_id` refers to an existing user.
writer
.write_users(vec![MasNewUser {
user_id: Uuid::from_u128(1u128),
username: "alice".to_owned(),
created_at: DateTime::default(),
locked_at: None,
can_request_admin: false,
}])
.await
.expect("failed to write user");
// Provider 4 is the one created by the `upstream_provider` fixture.
writer
.write_upstream_oauth_links(vec![MasNewUpstreamOauthLink {
user_id: Uuid::from_u128(1u128),
link_id: Uuid::from_u128(3u128),
upstream_provider_id: Uuid::from_u128(4u128),
subject: "12345.67890".to_owned(),
created_at: DateTime::default(),
}])
.await
.expect("failed to write link");
// Finish the writer before snapshotting the resulting database state.
writer.finish().await.expect("failed to finish MasWriter");
assert_db_snapshot!(&mut conn);
}
}

View File

@@ -0,0 +1,42 @@
---
source: crates/syn2mas/src/mas_writer/mod.rs
expression: db_snapshot
---
upstream_oauth_links:
- created_at: "1970-01-01 00:00:00+00"
human_account_name: ~
subject: "12345.67890"
upstream_oauth_link_id: 00000000-0000-0000-0000-000000000003
upstream_oauth_provider_id: 00000000-0000-0000-0000-000000000004
user_id: 00000000-0000-0000-0000-000000000001
upstream_oauth_providers:
- additional_parameters: ~
authorization_endpoint_override: ~
brand_name: ~
claims_imports: "{}"
client_id: someClientId
created_at: "2011-12-13 14:15:16+00"
disabled_at: ~
discovery_mode: oidc
encrypted_client_secret: ~
fetch_userinfo: "false"
human_name: ~
id_token_signed_response_alg: RS256
issuer: ~
jwks_uri_override: ~
pkce_mode: auto
response_mode: query
scope: openid
token_endpoint_auth_method: client_secret_basic
token_endpoint_override: ~
token_endpoint_signing_alg: ~
upstream_oauth_provider_id: 00000000-0000-0000-0000-000000000004
userinfo_endpoint_override: ~
userinfo_signed_response_alg: ~
users:
- can_request_admin: "false"
created_at: "1970-01-01 00:00:00+00"
locked_at: ~
primary_user_email_id: ~
user_id: 00000000-0000-0000-0000-000000000001
username: alice

View File

@@ -12,3 +12,4 @@ ALTER TABLE syn2mas__users RENAME TO users;
ALTER TABLE syn2mas__user_passwords RENAME TO user_passwords;
ALTER TABLE syn2mas__user_emails RENAME TO user_emails;
ALTER TABLE syn2mas__user_unsupported_third_party_ids RENAME TO user_unsupported_third_party_ids;
ALTER TABLE syn2mas__upstream_oauth_links RENAME TO upstream_oauth_links;

View File

@@ -41,3 +41,4 @@ ALTER TABLE users RENAME TO syn2mas__users;
ALTER TABLE user_passwords RENAME TO syn2mas__user_passwords;
ALTER TABLE user_emails RENAME TO syn2mas__user_emails;
ALTER TABLE user_unsupported_third_party_ids RENAME TO syn2mas__user_unsupported_third_party_ids;
ALTER TABLE upstream_oauth_links RENAME TO syn2mas__upstream_oauth_links;

View File

@@ -25,10 +25,12 @@ use uuid::Uuid;
use crate::{
mas_writer::{
self, MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUser, MasNewUserPassword,
MasThreepidWriteBuffer, MasUserWriteBuffer, MasWriter,
self, MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUpstreamOauthLink, MasNewUser,
MasNewUserPassword, MasWriteBuffer, MasWriter,
},
synapse_reader::{
self, ExtractLocalpartError, FullUserId, SynapseExternalId, SynapseThreepid, SynapseUser,
},
synapse_reader::{self, ExtractLocalpartError, FullUserId, SynapseThreepid, SynapseUser},
SynapseReader,
};
@@ -49,6 +51,16 @@ pub enum Error {
source: ExtractLocalpartError,
user: FullUserId,
},
#[error("user {user} was not found for migration but a row in {table} was found for them")]
MissingUserFromDependentTable { table: String, user: FullUserId },
#[error("missing a mapping for the auth provider with ID {synapse_id:?} (used by {user} and maybe other users)")]
MissingAuthProviderMapping {
/// `auth_provider` ID of the provider in Synapse, for which we have no
/// mapping
synapse_id: String,
/// a user that is using this auth provider
user: FullUserId,
},
}
struct UsersMigrated {
@@ -68,11 +80,13 @@ struct UsersMigrated {
///
/// - An underlying database access error, either to MAS or to Synapse.
/// - Invalid data in the Synapse database.
#[allow(clippy::implicit_hasher)]
pub async fn migrate(
synapse: &mut SynapseReader<'_>,
mas: &mut MasWriter<'_>,
server_name: &str,
rng: &mut impl RngCore,
provider_id_mapping: &HashMap<String, Uuid>,
) -> Result<(), Error> {
let counts = synapse.count_rows().await.into_synapse("counting users")?;
@@ -97,6 +111,16 @@ pub async fn migrate(
)
.await?;
migrate_external_ids(
synapse,
mas,
server_name,
rng,
&migrated_users.user_localparts_to_uuid,
provider_id_mapping,
)
.await?;
Ok(())
}
@@ -108,7 +132,8 @@ async fn migrate_users(
server_name: &str,
rng: &mut impl RngCore,
) -> Result<UsersMigrated, Error> {
let mut write_buffer = MasUserWriteBuffer::new(mas);
let mut user_buffer = MasWriteBuffer::new(MasWriter::write_users);
let mut password_buffer = MasWriteBuffer::new(MasWriter::write_passwords);
let mut users_stream = pin!(synapse.read_users());
// TODO is 1:1 capacity enough for a hashmap?
let mut user_localparts_to_uuid = HashMap::with_capacity(user_count_hint);
@@ -119,23 +144,24 @@ async fn migrate_users(
user_localparts_to_uuid.insert(CompactString::new(&mas_user.username), mas_user.user_id);
write_buffer
.write_user(mas_user)
user_buffer
.write(mas, mas_user)
.await
.into_mas("writing user")?;
if let Some(mas_password) = mas_password_opt {
write_buffer
.write_password(mas_password)
password_buffer
.write(mas, mas_password)
.await
.into_mas("writing password")?;
}
}
write_buffer
.finish()
user_buffer.finish(mas).await.into_mas("writing users")?;
password_buffer
.finish(mas)
.await
.into_mas("writing users & passwords")?;
.into_mas("writing passwords")?;
Ok(UsersMigrated {
user_localparts_to_uuid,
@@ -150,7 +176,8 @@ async fn migrate_threepids(
rng: &mut impl RngCore,
user_localparts_to_uuid: &HashMap<CompactString, Uuid>,
) -> Result<(), Error> {
let mut write_buffer = MasThreepidWriteBuffer::new(mas);
let mut email_buffer = MasWriteBuffer::new(MasWriter::write_email_threepids);
let mut unsupported_buffer = MasWriteBuffer::new(MasWriter::write_unsupported_threepids);
let mut users_stream = pin!(synapse.read_threepids());
while let Some(threepid_res) = users_stream.next().await {
@@ -167,36 +194,121 @@ async fn migrate_threepids(
.into_extract_localpart(synapse_user_id.clone())?
.to_owned();
let Some(user_id) = user_localparts_to_uuid.get(username.as_str()).copied() else {
todo!()
return Err(Error::MissingUserFromDependentTable {
table: "user_threepids".to_owned(),
user: synapse_user_id,
});
};
if medium == "email" {
write_buffer
.write_email(MasNewEmailThreepid {
user_id,
user_email_id: Uuid::from(Ulid::from_datetime_with_source(
created_at.into(),
rng,
)),
email: address,
created_at,
})
email_buffer
.write(
mas,
MasNewEmailThreepid {
user_id,
user_email_id: Uuid::from(Ulid::from_datetime_with_source(
created_at.into(),
rng,
)),
email: address,
created_at,
},
)
.await
.into_mas("writing email")?;
} else {
write_buffer
.write_password(MasNewUnsupportedThreepid {
user_id,
medium,
address,
created_at,
})
unsupported_buffer
.write(
mas,
MasNewUnsupportedThreepid {
user_id,
medium,
address,
created_at,
},
)
.await
.into_mas("writing unsupported threepid")?;
}
}
write_buffer.finish().await.into_mas("writing threepids")?;
email_buffer
.finish(mas)
.await
.into_mas("writing email threepids")?;
unsupported_buffer
.finish(mas)
.await
.into_mas("writing unsupported threepids")?;
Ok(())
}
/// Migrates Synapse external IDs to MAS as upstream OAuth 2.0 links.
///
/// # Parameters
///
/// - `provider_id_mapping`: mapping from Synapse `auth_provider` ID to UUID of
///   the upstream provider in MAS.
///
/// # Errors
///
/// - An underlying database access error, either to MAS or to Synapse.
/// - An external ID row referencing a user that was not migrated.
/// - A Synapse `auth_provider` ID with no entry in `provider_id_mapping`.
#[tracing::instrument(skip_all, level = Level::INFO)]
async fn migrate_external_ids(
    synapse: &mut SynapseReader<'_>,
    mas: &mut MasWriter<'_>,
    server_name: &str,
    rng: &mut impl RngCore,
    user_localparts_to_uuid: &HashMap<CompactString, Uuid>,
    provider_id_mapping: &HashMap<String, Uuid>,
) -> Result<(), Error> {
    let mut write_buffer = MasWriteBuffer::new(MasWriter::write_upstream_oauth_links);
    let mut extids_stream = pin!(synapse.read_user_external_ids());

    while let Some(extid_res) = extids_stream.next().await {
        let SynapseExternalId {
            user_id: synapse_user_id,
            auth_provider,
            external_id: subject,
        } = extid_res.into_synapse("reading external ID")?;

        let username = synapse_user_id
            .extract_localpart(server_name)
            .into_extract_localpart(synapse_user_id.clone())?
            .to_owned();
        let Some(user_id) = user_localparts_to_uuid.get(username.as_str()).copied() else {
            return Err(Error::MissingUserFromDependentTable {
                table: "user_external_ids".to_owned(),
                user: synapse_user_id,
            });
        };

        let Some(&upstream_provider_id) = provider_id_mapping.get(&auth_provider) else {
            return Err(Error::MissingAuthProviderMapping {
                synapse_id: auth_provider,
                user: synapse_user_id,
            });
        };

        // To save having to store user creation times, extract it from the ULID
        // This gives millisecond precision — good enough.
        let user_created_ts = Ulid::from(user_id).datetime();
        let link_id: Uuid = Ulid::from_datetime_with_source(user_created_ts, rng).into();

        write_buffer
            .write(
                mas,
                MasNewUpstreamOauthLink {
                    link_id,
                    user_id,
                    upstream_provider_id,
                    subject,
                    created_at: user_created_ts.into(),
                },
            )
            .await
            .into_mas("failed to write upstream link")?;
    }

    write_buffer
        .finish(mas)
        .await
        // was: "writing threepids" — copy-pasted from migrate_threepids;
        // this flush writes upstream OAuth links.
        .into_mas("writing upstream oauth links")?;

    Ok(())
}

View File

@@ -0,0 +1,12 @@
INSERT INTO user_external_ids
(
user_id,
auth_provider,
external_id
)
VALUES
(
'@alice:example.com',
'oidc-raasu',
'871.syn30'
);

View File

@@ -8,6 +8,8 @@
//! This module provides facilities for streaming relevant types of database
//! records from a Synapse database.
use std::fmt::Display;
use chrono::{DateTime, Utc};
use futures_util::{Stream, TryStreamExt};
use sqlx::{query, Acquire, FromRow, PgConnection, Postgres, Row, Transaction, Type};
@@ -30,6 +32,12 @@ pub enum Error {
#[derive(Clone, Debug, sqlx::Decode, PartialEq, Eq, PartialOrd, Ord)]
pub struct FullUserId(pub String);
impl Display for FullUserId {
/// Formats the full Matrix user ID (e.g. `@alice:example.com`) by
/// delegating to the inner `String`, which forwards any formatter flags.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl Type<Postgres> for FullUserId {
fn type_info() -> <sqlx::Postgres as sqlx::Database>::TypeInfo {
<String as Type<Postgres>>::type_info()
@@ -193,13 +201,21 @@ pub struct SynapseThreepid {
pub added_at: MillisecondsTimestamp,
}
/// Row of the `user_external_ids` table in Synapse.
#[derive(Clone, Debug, FromRow, PartialEq, Eq, PartialOrd, Ord)]
pub struct SynapseExternalId {
/// Full Matrix ID of the user, e.g. `@alice:example.com`.
pub user_id: FullUserId,
/// Synapse's ID for the authentication provider, e.g. `oidc-raasu`.
pub auth_provider: String,
/// The user's identifier at the external provider.
pub external_id: String,
}
/// List of Synapse tables that we should acquire an `EXCLUSIVE` lock on.
///
/// This is a safety measure against other processes changing the data
/// underneath our feet. It's still not a good idea to run Synapse at the same
/// time as the migration.
// TODO not complete!
const TABLES_TO_LOCK: &[&str] = &["users"];
const TABLES_TO_LOCK: &[&str] = &["users", "user_threepids", "user_external_ids"];
/// Number of migratable rows in various Synapse tables.
/// Used to estimate progress.
@@ -319,6 +335,21 @@ impl<'conn> SynapseReader<'conn> {
.fetch(&mut *self.txn)
.map_err(|err| err.into_database("reading Synapse threepids"))
}
/// Read associations between Synapse users and external identity providers
///
/// Streams every row of the `user_external_ids` table; each item is a
/// `Result` so database errors surface per-row rather than aborting the
/// stream type itself.
pub fn read_user_external_ids(
&mut self,
) -> impl Stream<Item = Result<SynapseExternalId, Error>> + '_ {
sqlx::query_as(
"
SELECT
user_id, auth_provider, external_id
FROM user_external_ids
",
)
.fetch(&mut *self.txn)
// Wrap sqlx errors with context identifying this particular read.
.map_err(|err| err.into_database("reading Synapse user external IDs"))
}
}
#[cfg(test)]
@@ -330,7 +361,7 @@ mod test {
use sqlx::{migrate::Migrator, PgPool};
use crate::{
synapse_reader::{SynapseThreepid, SynapseUser},
synapse_reader::{SynapseExternalId, SynapseThreepid, SynapseUser},
SynapseReader,
};
@@ -368,4 +399,20 @@ mod test {
assert_debug_snapshot!(threepids);
}
/// Tests reading external IDs from a Synapse database populated by the
/// `user_alice` and `external_ids_alice` fixtures.
#[sqlx::test(migrator = "MIGRATOR", fixtures("user_alice", "external_ids_alice"))]
async fn test_read_external_ids(pool: PgPool) {
let mut conn = pool.acquire().await.expect("failed to get connection");
// NOTE(review): second argument presumably toggles dry-run/locking
// behaviour — confirm against `SynapseReader::new`.
let mut reader = SynapseReader::new(&mut conn, false)
.await
.expect("failed to make SynapseReader");
// Collect into a BTreeSet so the snapshot output is deterministically
// ordered regardless of row order returned by the database.
let external_ids: BTreeSet<SynapseExternalId> = reader
.read_user_external_ids()
.try_collect()
.await
.expect("failed to read Synapse external user IDs");
assert_debug_snapshot!(external_ids);
}
}

View File

@@ -0,0 +1,13 @@
---
source: crates/syn2mas/src/synapse_reader/mod.rs
expression: external_ids
---
{
SynapseExternalId {
user_id: FullUserId(
"@alice:example.com",
),
auth_provider: "oidc-raasu",
external_id: "871.syn30",
},
}

View File

@@ -0,0 +1,12 @@
-- Copyright 2025 New Vector Ltd.
--
-- SPDX-License-Identifier: AGPL-3.0-only
-- Please see LICENSE in the repository root for full details.
-- Brings in the `user_external_ids` table from Synapse
CREATE TABLE user_external_ids (
auth_provider text NOT NULL,
external_id text NOT NULL,
user_id text NOT NULL
);