Add MasWriter support for compat refresh tokens + some migration progress

This commit is contained in:
Olivier 'reivilibre
2025-01-30 16:33:19 +00:00
committed by Quentin Gliech
parent 8577d3f9fe
commit ca32c5ebff
10 changed files with 349 additions and 25 deletions
Generated
+1
View File
@@ -6113,6 +6113,7 @@ dependencies = [
"futures-util",
"insta",
"mas-config",
"mas-storage",
"mas-storage-pg",
"rand",
"serde",
+2
View File
@@ -223,6 +223,7 @@ impl Options {
}
let mut writer = MasWriter::new(mas_connection, writer_mas_connections).await?;
let clock = SystemClock::default();
// TODO is this rng ok?
#[allow(clippy::disallowed_methods)]
let mut rng = thread_rng();
@@ -233,6 +234,7 @@ impl Options {
&mut reader,
&mut writer,
&mas_matrix.homeserver,
&clock,
&mut rng,
&provider_id_mappings,
)
@@ -0,0 +1,18 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO syn2mas__compat_refresh_tokens (\n compat_refresh_token_id,\n compat_session_id,\n compat_access_token_id,\n refresh_token,\n created_at)\n SELECT * FROM UNNEST(\n $1::UUID[],\n $2::UUID[],\n $3::UUID[],\n $4::TEXT[],\n $5::TIMESTAMP WITH TIME ZONE[])\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"UuidArray",
"UuidArray",
"UuidArray",
"TextArray",
"TimestamptzArray"
]
},
"nullable": []
},
"hash": "88975196c4c174d464b33aa015ce5d8cac3836701fc24922f4f0e8b98d330796"
}
+1
View File
@@ -28,6 +28,7 @@ uuid = "1.10.0"
ulid = { workspace = true, features = ["uuid"] }
mas-config.workspace = true
mas-storage.workspace = true
[dev-dependencies]
mas-storage-pg.workspace = true
+136 -7
View File
@@ -233,7 +233,7 @@ pub struct MasNewUpstreamOauthLink {
pub struct MasNewCompatSession {
pub session_id: Uuid,
pub user_id: Uuid,
pub device_id: String,
pub device_id: Option<String>,
pub human_name: Option<String>,
pub created_at: DateTime<Utc>,
pub is_synapse_admin: bool,
@@ -250,6 +250,14 @@ pub struct MasNewCompatAccessToken {
pub expires_at: Option<DateTime<Utc>>,
}
pub struct MasNewCompatRefreshToken {
pub refresh_token_id: Uuid,
pub session_id: Uuid,
pub access_token_id: Uuid,
pub refresh_token: String,
pub created_at: DateTime<Utc>,
}
/// The 'version' of the password hashing scheme used for passwords when they
/// are migrated from Synapse to MAS.
/// This is version 1, as in the previous syn2mas script.
@@ -795,7 +803,7 @@ impl<'conn> MasWriter<'conn> {
Box::pin(async move {
let mut session_ids: Vec<Uuid> = Vec::with_capacity(sessions.len());
let mut user_ids: Vec<Uuid> = Vec::with_capacity(sessions.len());
let mut device_ids: Vec<String> = Vec::with_capacity(sessions.len());
let mut device_ids: Vec<Option<String>> = Vec::with_capacity(sessions.len());
let mut human_names: Vec<Option<String>> = Vec::with_capacity(sessions.len());
let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(sessions.len());
let mut is_synapse_admins: Vec<bool> = Vec::with_capacity(sessions.len());
@@ -845,7 +853,7 @@ impl<'conn> MasWriter<'conn> {
"#,
&session_ids[..],
&user_ids[..],
&device_ids[..],
&device_ids[..] as &[Option<String>],
&human_names[..] as &[Option<String>],
&created_ats[..],
&is_synapse_admins[..],
@@ -925,6 +933,66 @@ impl<'conn> MasWriter<'conn> {
})
.boxed()
}
#[tracing::instrument(skip_all, level = Level::DEBUG)]
pub fn write_compat_refresh_tokens(
&mut self,
tokens: Vec<MasNewCompatRefreshToken>,
) -> BoxFuture<'_, Result<(), Error>> {
self.writer_pool
.spawn_with_connection(move |conn| {
Box::pin(async move {
let mut refresh_token_ids: Vec<Uuid> = Vec::with_capacity(tokens.len());
let mut session_ids: Vec<Uuid> = Vec::with_capacity(tokens.len());
let mut access_token_ids: Vec<Uuid> = Vec::with_capacity(tokens.len());
let mut refresh_tokens: Vec<String> = Vec::with_capacity(tokens.len());
let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(tokens.len());
for MasNewCompatRefreshToken {
refresh_token_id,
session_id,
access_token_id,
refresh_token,
created_at,
} in tokens
{
refresh_token_ids.push(refresh_token_id);
session_ids.push(session_id);
access_token_ids.push(access_token_id);
refresh_tokens.push(refresh_token);
created_ats.push(created_at);
}
sqlx::query!(
r#"
INSERT INTO syn2mas__compat_refresh_tokens (
compat_refresh_token_id,
compat_session_id,
compat_access_token_id,
refresh_token,
created_at)
SELECT * FROM UNNEST(
$1::UUID[],
$2::UUID[],
$3::UUID[],
$4::TEXT[],
$5::TIMESTAMP WITH TIME ZONE[])
"#,
&refresh_token_ids[..],
&session_ids[..],
&access_token_ids[..],
&refresh_tokens[..],
&created_ats[..],
)
.execute(&mut *conn)
.await
.into_database("writing compat refresh tokens to MAS")?;
Ok(())
})
})
.boxed()
}
}
// How many entries to buffer at once, before writing a batch of rows to the
@@ -1003,8 +1071,9 @@ mod test {
use crate::{
mas_writer::{
MasNewCompatAccessToken, MasNewCompatSession, MasNewEmailThreepid,
MasNewUnsupportedThreepid, MasNewUpstreamOauthLink, MasNewUser, MasNewUserPassword,
MasNewCompatAccessToken, MasNewCompatRefreshToken, MasNewCompatSession,
MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUpstreamOauthLink, MasNewUser,
MasNewUserPassword,
},
LockedMasDatabase, MasWriter,
};
@@ -1292,7 +1361,7 @@ mod test {
user_id: Uuid::from_u128(1u128),
session_id: Uuid::from_u128(5u128),
created_at: DateTime::default(),
device_id: "ADEVICE".to_owned(),
device_id: Some("ADEVICE".to_owned()),
human_name: Some("alice's pinephone".to_owned()),
is_synapse_admin: true,
last_active_at: Some(DateTime::default()),
@@ -1329,7 +1398,7 @@ mod test {
user_id: Uuid::from_u128(1u128),
session_id: Uuid::from_u128(5u128),
created_at: DateTime::default(),
device_id: "ADEVICE".to_owned(),
device_id: Some("ADEVICE".to_owned()),
human_name: None,
is_synapse_admin: false,
last_active_at: None,
@@ -1354,4 +1423,64 @@ mod test {
assert_db_snapshot!(&mut conn);
}
/// Tests writing a single user, with a device, an access token and a
/// refresh token.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_write_user_with_refresh_token(pool: PgPool) {
let mut conn = pool.acquire().await.unwrap();
let mut writer = make_mas_writer(&pool, &mut conn).await;
writer
.write_users(vec![MasNewUser {
user_id: Uuid::from_u128(1u128),
username: "alice".to_owned(),
created_at: DateTime::default(),
locked_at: None,
can_request_admin: false,
}])
.await
.expect("failed to write user");
writer
.write_compat_sessions(vec![MasNewCompatSession {
user_id: Uuid::from_u128(1u128),
session_id: Uuid::from_u128(5u128),
created_at: DateTime::default(),
device_id: Some("ADEVICE".to_owned()),
human_name: None,
is_synapse_admin: false,
last_active_at: None,
last_active_ip: None,
user_agent: None,
}])
.await
.expect("failed to write compat session");
writer
.write_compat_access_tokens(vec![MasNewCompatAccessToken {
token_id: Uuid::from_u128(6u128),
session_id: Uuid::from_u128(5u128),
access_token: "syt_zxcvzxcvzxcvzxcv_zxcv".to_owned(),
created_at: DateTime::default(),
expires_at: None,
}])
.await
.expect("failed to write access token");
writer
.write_compat_refresh_tokens(vec![MasNewCompatRefreshToken {
refresh_token_id: Uuid::from_u128(7u128),
session_id: Uuid::from_u128(5u128),
access_token_id: Uuid::from_u128(6u128),
refresh_token: "syr_zxcvzxcvzxcvzxcv_zxcv".to_owned(),
created_at: DateTime::default(),
}])
.await
.expect("failed to write refresh token");
writer.finish().await.expect("failed to finish MasWriter");
assert_db_snapshot!(&mut conn);
}
}
@@ -0,0 +1,36 @@
---
source: crates/syn2mas/src/mas_writer/mod.rs
expression: db_snapshot
---
compat_access_tokens:
- access_token: syt_zxcvzxcvzxcvzxcv_zxcv
compat_access_token_id: 00000000-0000-0000-0000-000000000006
compat_session_id: 00000000-0000-0000-0000-000000000005
created_at: "1970-01-01 00:00:00+00"
expires_at: ~
compat_refresh_tokens:
- compat_access_token_id: 00000000-0000-0000-0000-000000000006
compat_refresh_token_id: 00000000-0000-0000-0000-000000000007
compat_session_id: 00000000-0000-0000-0000-000000000005
consumed_at: ~
created_at: "1970-01-01 00:00:00+00"
refresh_token: syr_zxcvzxcvzxcvzxcv_zxcv
compat_sessions:
- compat_session_id: 00000000-0000-0000-0000-000000000005
created_at: "1970-01-01 00:00:00+00"
device_id: ADEVICE
finished_at: ~
human_name: ~
is_synapse_admin: "false"
last_active_at: ~
last_active_ip: ~
user_agent: ~
user_id: 00000000-0000-0000-0000-000000000001
user_session_id: ~
users:
- can_request_admin: "false"
created_at: "1970-01-01 00:00:00+00"
locked_at: ~
primary_user_email_id: ~
user_id: 00000000-0000-0000-0000-000000000001
username: alice
+153 -11
View File
@@ -19,6 +19,7 @@ use std::{
use chrono::{DateTime, Utc};
use compact_str::CompactString;
use futures_util::StreamExt as _;
use mas_storage::Clock;
use rand::RngCore;
use thiserror::Error;
use thiserror_ext::ContextInto;
@@ -28,12 +29,13 @@ use uuid::Uuid;
use crate::{
mas_writer::{
self, MasNewCompatSession, MasNewEmailThreepid, MasNewUnsupportedThreepid,
MasNewUpstreamOauthLink, MasNewUser, MasNewUserPassword, MasWriteBuffer, MasWriter,
self, MasNewCompatAccessToken, MasNewCompatSession, MasNewEmailThreepid,
MasNewUnsupportedThreepid, MasNewUpstreamOauthLink, MasNewUser, MasNewUserPassword,
MasWriteBuffer, MasWriter,
},
synapse_reader::{
self, ExtractLocalpartError, FullUserId, SynapseDevice, SynapseExternalId, SynapseThreepid,
SynapseUser,
self, ExtractLocalpartError, FullUserId, SynapseAccessToken, SynapseDevice,
SynapseExternalId, SynapseRefreshToken, SynapseThreepid, SynapseUser,
},
SynapseReader,
};
@@ -92,6 +94,7 @@ pub async fn migrate(
synapse: &mut SynapseReader<'_>,
mas: &mut MasWriter<'_>,
server_name: &str,
clock: &dyn Clock,
rng: &mut impl RngCore,
provider_id_mapping: &HashMap<String, Uuid>,
) -> Result<(), Error> {
@@ -137,7 +140,18 @@ pub async fn migrate(
.expect("More than usize::MAX devices — unable to handle this many!"),
);
migrate_access_and_refresh_tokens(
migrate_access_tokens(
synapse,
mas,
server_name,
clock,
rng,
&migrated_users.user_localparts_to_uuid,
&mut devices_to_compat_sessions,
)
.await?;
migrate_refresh_tokens(
synapse,
mas,
server_name,
@@ -433,7 +447,7 @@ async fn migrate_devices(
MasNewCompatSession {
session_id,
user_id,
device_id,
device_id: Some(device_id),
human_name: display_name,
created_at,
is_synapse_admin: synapse_admins.contains(&user_id),
@@ -455,7 +469,110 @@ async fn migrate_devices(
}
#[tracing::instrument(skip_all, level = Level::INFO)]
async fn migrate_access_and_refresh_tokens(
async fn migrate_access_tokens(
synapse: &mut SynapseReader<'_>,
mas: &mut MasWriter<'_>,
server_name: &str,
clock: &dyn Clock,
rng: &mut impl RngCore,
user_localparts_to_uuid: &HashMap<CompactString, Uuid>,
devices: &mut HashMap<(Uuid, CompactString), Uuid>,
) -> Result<(), Error> {
let mut token_stream = pin!(synapse.read_access_tokens());
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_access_tokens);
let mut deviceless_session_write_buffer = MasWriteBuffer::new(MasWriter::write_compat_sessions);
while let Some(token_res) = token_stream.next().await {
let SynapseAccessToken {
user_id: synapse_user_id,
device_id,
token,
valid_until_ms,
last_validated,
} = token_res.into_synapse("reading Synapse access token")?;
let username = synapse_user_id
.extract_localpart(server_name)
.into_extract_localpart(synapse_user_id.clone())?
.to_owned();
let Some(user_id) = user_localparts_to_uuid.get(username.as_str()).copied() else {
return Err(Error::MissingUserFromDependentTable {
table: "devices".to_owned(),
user: synapse_user_id,
});
};
// It's not always accurate, but last_validated is *often* the creation time of
// the device If we don't have one, then use the current time as a
// fallback.
let created_at = last_validated.map_or_else(|| clock.now(), DateTime::from);
let session_id = if let Some(device_id) = device_id {
// Use the existing device_id if this is the second token for a device
*devices
.entry((user_id, CompactString::new(&device_id)))
.or_insert_with(|| {
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng))
})
} else {
// If this is a deviceless access token, create a deviceless compat session
// for it (since otherwise we won't create one whilst migrating devices)
let deviceless_session_id =
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng));
deviceless_session_write_buffer
.write(
mas,
MasNewCompatSession {
session_id: deviceless_session_id,
user_id,
device_id: None,
human_name: None,
created_at,
is_synapse_admin: false,
last_active_at: None,
last_active_ip: None,
user_agent: None,
},
)
.await
.into_mas("failed to write deviceless compat sessions")?;
deviceless_session_id
};
let token_id = Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng));
// TODO skip access tokens for deactivated users
write_buffer
.write(
mas,
MasNewCompatAccessToken {
token_id,
session_id,
access_token: token,
created_at,
expires_at: valid_until_ms.map(DateTime::from),
},
)
.await
.into_mas("writing compat access tokens")?;
}
write_buffer
.finish(mas)
.await
.into_mas("writing compat access tokens")?;
deviceless_session_write_buffer
.finish(mas)
.await
.into_mas("writing deviceless compat sessions")?;
Ok(())
}
#[tracing::instrument(skip_all, level = Level::INFO)]
async fn migrate_refresh_tokens(
synapse: &mut SynapseReader<'_>,
mas: &mut MasWriter<'_>,
server_name: &str,
@@ -463,10 +580,35 @@ async fn migrate_access_and_refresh_tokens(
user_localparts_to_uuid: &HashMap<CompactString, Uuid>,
devices: &mut HashMap<(Uuid, CompactString), Uuid>,
) -> Result<(), Error> {
let mut access_token_stream = pin!(synapse.read_access_tokens());
// let mut write_buffer =
// MasWriteBuffer::new(MasWriter::write_compat_access_token);
todo!();
let mut token_stream = pin!(synapse.read_refresh_tokens());
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_refresh_tokens);
while let Some(token_res) = token_stream.next().await {
let SynapseRefreshToken {
user_id: synapse_user_id,
device_id,
token,
id,
} = token_res.into_synapse("reading Synapse refresh token")?;
let username = synapse_user_id
.extract_localpart(server_name)
.into_extract_localpart(synapse_user_id.clone())?
.to_owned();
let Some(user_id) = user_localparts_to_uuid.get(username.as_str()).copied() else {
return Err(Error::MissingUserFromDependentTable {
table: "devices".to_owned(),
user: synapse_user_id,
});
};
todo!()
}
write_buffer
.finish(mas)
.await
.into_mas("writing compat refresh tokens")?;
Ok(())
}
+2 -3
View File
@@ -12,7 +12,7 @@ use std::fmt::Display;
use chrono::{DateTime, Utc};
use futures_util::{Stream, TryStreamExt};
use sqlx::{query, Acquire, FromRow, PgConnection, Postgres, Row, Transaction, Type};
use sqlx::{query, Acquire, FromRow, PgConnection, Postgres, Transaction, Type};
use thiserror::Error;
use thiserror_ext::ContextInto;
@@ -228,7 +228,6 @@ pub struct SynapseAccessToken {
pub token: String,
pub valid_until_ms: Option<MillisecondsTimestamp>,
pub last_validated: Option<MillisecondsTimestamp>,
pub refresh_token_id: Option<i64>,
}
/// Row of the `refresh_tokens` table in Synapse.
@@ -426,7 +425,7 @@ impl<'conn> SynapseReader<'conn> {
sqlx::query_as(
"
SELECT
at0.user_id, at0.device_id, at0.token, at0.valid_until_ms, at0.last_validated, at0.refresh_token_id
at0.user_id, at0.device_id, at0.token, at0.valid_until_ms, at0.last_validated
FROM access_tokens at0
LEFT JOIN refresh_tokens rt0 ON at0.refresh_token_id = rt0.id
LEFT JOIN access_tokens at1 ON rt0.next_token_id = at1.refresh_token_id
@@ -13,8 +13,5 @@ expression: access_tokens
token: "syt_AAAAAAAAAAAAAA_AAAA",
valid_until_ms: None,
last_validated: None,
refresh_token_id: Some(
8,
),
},
}
@@ -13,6 +13,5 @@ expression: access_tokens
token: "syt_aaaaaaaaaaaaaa_aaaa",
valid_until_ms: None,
last_validated: None,
refresh_token_id: None,
},
}