diff --git a/matrix-authentication-service/src/filters/database.rs b/matrix-authentication-service/src/filters/database.rs new file mode 100644 index 000000000..30aeb2828 --- /dev/null +++ b/matrix-authentication-service/src/filters/database.rs @@ -0,0 +1,54 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::convert::Infallible; + +use sqlx::{pool::PoolConnection, PgPool, Postgres, Transaction}; +use warp::{Filter, Rejection}; + +use crate::errors::WrapError; + +fn with_pool( + pool: &PgPool, +) -> impl Filter + Clone + Send + Sync + 'static { + let pool = pool.clone(); + warp::any().map(move || pool.clone()) +} + +pub fn with_connection( + pool: &PgPool, +) -> impl Filter,), Error = Rejection> + Clone + Send + Sync + 'static +{ + with_pool(pool).and_then(acquire_connection) +} + +async fn acquire_connection(pool: PgPool) -> Result, Rejection> { + let conn = pool.acquire().await.wrap_error()?; + Ok(conn) +} + +pub fn with_transaction( + pool: &PgPool, +) -> impl Filter,), Error = Rejection> + + Clone + + Send + + Sync + + 'static { + with_pool(pool).and_then(acquire_transaction) +} + +async fn acquire_transaction(pool: PgPool) -> Result, Rejection> { + let txn = pool.begin().await.wrap_error()?; + Ok(txn) +} diff --git a/matrix-authentication-service/src/filters/mod.rs b/matrix-authentication-service/src/filters/mod.rs index 19c7cd6d5..14090b0fd 100644 --- a/matrix-authentication-service/src/filters/mod.rs +++ b/matrix-authentication-service/src/filters/mod.rs @@ -15,23 +15,16 @@ pub mod csrf; // mod errors; pub mod cookies; +pub mod database; pub mod session; use std::convert::Infallible; -use sqlx::PgPool; use warp::Filter; pub use self::csrf::CsrfToken; use crate::templates::Templates; -pub fn with_pool( - pool: &PgPool, -) -> impl Filter + Clone + Send + Sync + 'static { - let pool = pool.clone(); - warp::any().map(move || pool.clone()) -} - pub fn with_templates( templates: &Templates, ) -> impl Filter + Clone + Send + Sync + 'static { diff --git a/matrix-authentication-service/src/filters/session.rs b/matrix-authentication-service/src/filters/session.rs index ec5238c58..3b7266414 100644 --- a/matrix-authentication-service/src/filters/session.rs +++ b/matrix-authentication-service/src/filters/session.rs @@ -14,12 +14,12 @@ use headers::SetCookie; use serde::{Deserialize, Serialize}; -use sqlx::{Executor, PgPool, Postgres}; +use sqlx::{pool::PoolConnection, Executor, PgPool, Postgres}; use warp::{filters::BoxedFilter, Filter, Rejection, Reply}; use super::{ cookies::{encrypted, maybe_encrypted, save_encrypted, WithTypedHeader}, - with_pool, + database::with_connection, }; use crate::{ config::CookiesConfig, @@ -53,15 +53,17 @@ pub fn with_optional_session( ) -> impl Filter,), Error = Rejection> + Clone + Send + Sync + 'static { maybe_encrypted("session", cookies_config) - .and(with_pool(pool)) - .and_then(|maybe_session: Option, pool: PgPool| async move { - let maybe_session_info = if let Some(session) = maybe_session { - session.load_session_info(&pool).await.ok() - } else { - None - }; - Ok::<_, Rejection>(maybe_session_info) - }) + .and(with_connection(pool)) + .and_then( + |maybe_session: Option, mut conn: PoolConnection| async move { + let maybe_session_info = if let Some(session) = maybe_session { + session.load_session_info(&mut conn).await.ok() + } else { + None + }; + Ok::<_, Rejection>(maybe_session_info) + }, + ) } pub fn with_session( @@ -69,11 +71,13 @@ pub fn with_session( cookies_config: &CookiesConfig, ) -> impl Filter + Clone + Send + Sync + 'static { encrypted("session", cookies_config) - .and(with_pool(pool)) - .and_then(|session: Session, pool: PgPool| async move { - let session_info = session.load_session_info(&pool).await.wrap_error()?; - Ok::<_, Rejection>(session_info) - }) + .and(with_connection(pool)) + .and_then( + |session: Session, mut conn: PoolConnection| async move { + let session_info = session.load_session_info(&mut conn).await.wrap_error()?; + Ok::<_, Rejection>(session_info) + }, + ) } pub fn save_session( diff --git a/matrix-authentication-service/src/handlers/health.rs b/matrix-authentication-service/src/handlers/health.rs index 623b295c1..6b37d158d 100644 --- a/matrix-authentication-service/src/handlers/health.rs +++ b/matrix-authentication-service/src/handlers/health.rs @@ -14,25 +14,25 @@ use hyper::header::CONTENT_TYPE; use mime::TEXT_PLAIN; -use sqlx::PgPool; +use sqlx::{pool::PoolConnection, PgPool, Postgres}; use tracing::{info_span, Instrument}; use warp::{reply::with_header, Filter, Rejection, Reply}; -use crate::{errors::WrapError, filters::with_pool}; +use crate::{errors::WrapError, filters::database::with_connection}; pub fn filter( pool: &PgPool, ) -> impl Filter + Clone + Send + Sync + 'static { warp::get() .and(warp::path("health")) - .and(with_pool(pool)) + .and(with_connection(pool)) .and_then(get) } -async fn get(pool: PgPool) -> Result { +async fn get(mut conn: PoolConnection) -> Result { sqlx::query("SELECT $1") .bind(1_i64) - .execute(&pool) + .execute(&mut conn) .instrument(info_span!("DB health")) .await .wrap_error()?; diff --git a/matrix-authentication-service/src/handlers/oauth2/authorization.rs b/matrix-authentication-service/src/handlers/oauth2/authorization.rs index 2592a0be5..bac0f5416 100644 --- a/matrix-authentication-service/src/handlers/oauth2/authorization.rs +++ b/matrix-authentication-service/src/handlers/oauth2/authorization.rs @@ -32,7 +32,7 @@ use oauth2_types::{ }, }; use serde::{Deserialize, Serialize}; -use sqlx::PgPool; +use sqlx::{PgPool, Postgres, Transaction}; use url::Url; use warp::{ redirect::see_other, @@ -44,8 +44,9 @@ use crate::{ config::{CookiesConfig, OAuth2ClientConfig, OAuth2Config}, errors::WrapError, filters::{ + database::with_transaction, session::{with_optional_session, with_session}, - with_pool, with_templates, + with_templates, }, handlers::views::LoginRequest, storage::{ @@ -164,7 +165,7 @@ pub fn filter( .map(move || clients.clone()) .and(warp::query()) .and(with_optional_session(pool, cookies_config)) - .and(with_pool(pool)) + .and(with_transaction(pool)) .and(with_templates(templates)) .and_then(get); @@ -172,7 +173,7 @@ pub fn filter( .and(warp::path!("oauth2" / "authorize" / "step")) .and(warp::query().map(|s: StepRequest| s.id)) .and(with_session(pool, cookies_config)) - .and(with_pool(pool)) + .and(with_transaction(pool)) .and(with_templates(templates)) .and_then(step); @@ -183,7 +184,7 @@ async fn get( clients: Vec, params: Params, maybe_session: Option, - pool: PgPool, + mut txn: Transaction<'_, Postgres>, templates: Templates, ) -> Result, Rejection> { // First, find out what client it is @@ -198,8 +199,6 @@ async fn get( .resolve_redirect_uri(¶ms.auth.redirect_uri) .wrap_error()?; - // Start a DB transaction - let mut txn = pool.begin().await.wrap_error()?; let maybe_session_id = maybe_session.as_ref().map(SessionInfo::key); let scope: String = { @@ -237,14 +236,15 @@ async fn get( .wrap_error()?; }; - // Do we have a user in this session, with a last authentication time that - // matches the requirement? + // Do we already have a user session for this oauth2 session? let user_session = oauth2_session.fetch_session(&mut txn).await.wrap_error()?; - txn.commit().await.wrap_error()?; if let Some(user_session) = user_session { - step(oauth2_session.id, user_session, pool, templates).await + step(oauth2_session.id, user_session, txn, templates).await } else { + // If not, redirect the user to the login page + txn.commit().await.wrap_error()?; + let next = StepRequest::new(oauth2_session.id) .build_uri() .wrap_error()? @@ -280,12 +280,9 @@ impl StepRequest { async fn step( oauth2_session_id: i64, user_session: SessionInfo, - pool: PgPool, + mut txn: Transaction<'_, Postgres>, templates: Templates, ) -> Result, Rejection> { - // Start a DB transaction - let mut txn = pool.begin().await.wrap_error()?; - let mut oauth2_session = get_session_by_id(&mut txn, oauth2_session_id) .await .wrap_error()?; diff --git a/matrix-authentication-service/src/handlers/views/login.rs b/matrix-authentication-service/src/handlers/views/login.rs index a95cdccd0..5c3c63e82 100644 --- a/matrix-authentication-service/src/handlers/views/login.rs +++ b/matrix-authentication-service/src/handlers/views/login.rs @@ -16,7 +16,7 @@ use std::convert::TryFrom; use hyper::http::uri::{Parts, PathAndQuery, Uri}; use serde::{Deserialize, Serialize}; -use sqlx::PgPool; +use sqlx::{pool::PoolConnection, PgPool, Postgres}; use warp::{reply::html, wrap_fn, Filter, Rejection, Reply}; use crate::{ @@ -24,8 +24,9 @@ use crate::{ errors::WrapError, filters::{ csrf::{protected_form, save_csrf_token, updated_csrf_token}, - session::save_session, - with_pool, with_templates, CsrfToken, + database::with_connection, + session::{save_session, with_optional_session}, + with_templates, CsrfToken, }, storage::{login, SessionInfo}, templates::{TemplateContext, Templates}, @@ -51,6 +52,22 @@ impl LoginRequest { })?; Ok(uri) } + + fn redirect(self) -> Result { + let uri: Uri = Uri::from_parts({ + let mut parts = Parts::default(); + parts.path_and_query = Some( + self.next + .map(warp::http::uri::PathAndQuery::try_from) + .transpose() + .wrap_error()? + .unwrap_or_else(|| PathAndQuery::from_static("/")), + ); + parts + }) + .wrap_error()?; + Ok(warp::redirect::see_other(uri)) + } } #[derive(Deserialize)] @@ -68,12 +85,14 @@ pub(super) fn filter( let get = warp::get() .and(with_templates(templates)) .and(updated_csrf_token(cookies_config, csrf_config)) + .and(warp::query()) + .and(with_optional_session(pool, cookies_config)) .and_then(get) .untuple_one() .with(wrap_fn(save_csrf_token(cookies_config))); let post = warp::post() - .and(with_pool(pool)) + .and(with_connection(pool)) .and(protected_form(cookies_config)) .and(warp::query()) .and_then(post) @@ -86,36 +105,26 @@ pub(super) fn filter( async fn get( templates: Templates, csrf_token: CsrfToken, -) -> Result<(CsrfToken, impl Reply), Rejection> { - let ctx = ().with_csrf(&csrf_token); - - // TODO: check if there is an existing session - let content = templates.render_login(&ctx)?; - Ok((csrf_token, html(content))) + query: LoginRequest, + maybe_session: Option, +) -> Result<(CsrfToken, Box), Rejection> { + if maybe_session.is_some() { + Ok((csrf_token, Box::new(query.redirect()?))) + } else { + let ctx = ().with_csrf(&csrf_token); + let content = templates.render_login(&ctx)?; + Ok((csrf_token, Box::new(html(content)))) + } } async fn post( - db: PgPool, + mut conn: PoolConnection, form: LoginForm, query: LoginRequest, ) -> Result<(SessionInfo, impl Reply), Rejection> { - let session_info = login(&db, &form.username, &form.password) + let session_info = login(&mut conn, &form.username, &form.password) .await .wrap_error()?; - let uri: Uri = Uri::from_parts({ - let mut parts = Parts::default(); - parts.path_and_query = Some( - query - .next - .map(warp::http::uri::PathAndQuery::try_from) - .transpose() - .wrap_error()? - .unwrap_or_else(|| PathAndQuery::from_static("/")), - ); - parts - }) - .wrap_error()?; - - Ok((session_info, warp::redirect(uri))) + Ok((session_info, query.redirect()?)) } diff --git a/matrix-authentication-service/src/handlers/views/logout.rs b/matrix-authentication-service/src/handlers/views/logout.rs index 14628366e..76f419966 100644 --- a/matrix-authentication-service/src/handlers/views/logout.rs +++ b/matrix-authentication-service/src/handlers/views/logout.rs @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -use sqlx::PgPool; +use sqlx::{pool::PoolConnection, PgPool, Postgres}; use warp::{hyper::Uri, Filter, Rejection, Reply}; use crate::{ config::CookiesConfig, errors::WrapError, - filters::{csrf::protected_form, session::with_session, with_pool}, + filters::{csrf::protected_form, database::with_connection, session::with_session}, storage::SessionInfo, }; @@ -29,12 +29,16 @@ pub(super) fn filter( warp::post() .and(warp::path("logout")) .and(with_session(pool, cookies_config)) - .and(with_pool(pool)) + .and(with_connection(pool)) .and(protected_form(cookies_config)) .and_then(post) } -async fn post(session: SessionInfo, pool: PgPool, _form: ()) -> Result { - session.end(&pool).await.wrap_error()?; +async fn post( + session: SessionInfo, + mut conn: PoolConnection, + _form: (), +) -> Result { + session.end(&mut conn).await.wrap_error()?; Ok::<_, Rejection>(warp::redirect(Uri::from_static("/login"))) } diff --git a/matrix-authentication-service/src/handlers/views/reauth.rs b/matrix-authentication-service/src/handlers/views/reauth.rs index d087eb231..bd1555d69 100644 --- a/matrix-authentication-service/src/handlers/views/reauth.rs +++ b/matrix-authentication-service/src/handlers/views/reauth.rs @@ -13,7 +13,7 @@ // limitations under the License. use serde::Deserialize; -use sqlx::PgPool; +use sqlx::{pool::PoolConnection, PgPool, Postgres}; use warp::{hyper::Uri, reply::html, wrap_fn, Filter, Rejection, Reply}; use crate::{ @@ -21,8 +21,9 @@ use crate::{ errors::WrapError, filters::{ csrf::{protected_form, save_csrf_token, updated_csrf_token}, + database::with_connection, session::with_session, - with_pool, with_templates, CsrfToken, + with_templates, CsrfToken, }, storage::SessionInfo, templates::{TemplateContext, Templates}, @@ -49,7 +50,7 @@ pub(super) fn filter( let post = warp::post() .and(with_session(pool, cookies_config)) - .and(with_pool(pool)) + .and(with_connection(pool)) .and(protected_form(cookies_config)) .and_then(post); @@ -69,10 +70,13 @@ async fn get( async fn post( session: SessionInfo, - pool: PgPool, + mut conn: PoolConnection, form: ReauthForm, ) -> Result { - let _session = session.reauth(&pool, &form.password).await.wrap_error()?; + let _session = session + .reauth(&mut conn, &form.password) + .await + .wrap_error()?; Ok(warp::redirect(Uri::from_static("/"))) } diff --git a/matrix-authentication-service/src/storage/user.rs b/matrix-authentication-service/src/storage/user.rs index ef5dd24be..1a06dd8bd 100644 --- a/matrix-authentication-service/src/storage/user.rs +++ b/matrix-authentication-service/src/storage/user.rs @@ -20,7 +20,7 @@ use chrono::{DateTime, Utc}; use password_hash::{PasswordHash, PasswordHasher, SaltString}; use rand::rngs::OsRng; use serde::Serialize; -use sqlx::{Executor, FromRow, PgPool, Postgres, Transaction}; +use sqlx::{Acquire, Executor, FromRow, Postgres, Transaction}; use tracing::{info_span, Instrument}; #[derive(Serialize, Debug, Clone, FromRow)] @@ -44,8 +44,12 @@ impl SessionInfo { self.id } - pub async fn reauth(mut self, pool: &PgPool, password: &str) -> anyhow::Result { - let mut txn = pool.begin().await?; + pub async fn reauth( + mut self, + conn: impl Acquire<'_, Database = Postgres>, + password: &str, + ) -> anyhow::Result { + let mut txn = conn.begin().await?; self.last_authd_at = Some(authenticate_session(&mut txn, self.id, password).await?); txn.commit().await?; Ok(self) @@ -61,8 +65,12 @@ impl SessionInfo { } } -pub async fn login(pool: &PgPool, username: &str, password: &str) -> anyhow::Result { - let mut txn = pool.begin().await?; +pub async fn login( + conn: impl Acquire<'_, Database = Postgres>, + username: &str, + password: &str, +) -> anyhow::Result { + let mut txn = conn.begin().await?; let user = lookup_user_by_username(&mut txn, username).await?; let mut session = start_session(&mut txn, user).await?; session.last_authd_at = Some(authenticate_session(&mut txn, session.id, password).await?);