mirror of
https://github.com/element-hq/matrix-authentication-service.git
synced 2026-05-24 04:25:38 +00:00
Acquire DB conns and txns on filter level
This avoids having the pool everywhere and instead have connections and transactions as parameters
This commit is contained in:
@@ -0,0 +1,54 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres, Transaction};
|
||||
use warp::{Filter, Rejection};
|
||||
|
||||
use crate::errors::WrapError;
|
||||
|
||||
fn with_pool(
|
||||
pool: &PgPool,
|
||||
) -> impl Filter<Extract = (PgPool,), Error = Infallible> + Clone + Send + Sync + 'static {
|
||||
let pool = pool.clone();
|
||||
warp::any().map(move || pool.clone())
|
||||
}
|
||||
|
||||
pub fn with_connection(
|
||||
pool: &PgPool,
|
||||
) -> impl Filter<Extract = (PoolConnection<Postgres>,), Error = Rejection> + Clone + Send + Sync + 'static
|
||||
{
|
||||
with_pool(pool).and_then(acquire_connection)
|
||||
}
|
||||
|
||||
async fn acquire_connection(pool: PgPool) -> Result<PoolConnection<Postgres>, Rejection> {
|
||||
let conn = pool.acquire().await.wrap_error()?;
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
pub fn with_transaction(
|
||||
pool: &PgPool,
|
||||
) -> impl Filter<Extract = (Transaction<'static, Postgres>,), Error = Rejection>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static {
|
||||
with_pool(pool).and_then(acquire_transaction)
|
||||
}
|
||||
|
||||
async fn acquire_transaction(pool: PgPool) -> Result<Transaction<'static, Postgres>, Rejection> {
|
||||
let txn = pool.begin().await.wrap_error()?;
|
||||
Ok(txn)
|
||||
}
|
||||
@@ -15,23 +15,16 @@
|
||||
pub mod csrf;
|
||||
// mod errors;
|
||||
pub mod cookies;
|
||||
pub mod database;
|
||||
pub mod session;
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use sqlx::PgPool;
|
||||
use warp::Filter;
|
||||
|
||||
pub use self::csrf::CsrfToken;
|
||||
use crate::templates::Templates;
|
||||
|
||||
pub fn with_pool(
|
||||
pool: &PgPool,
|
||||
) -> impl Filter<Extract = (PgPool,), Error = Infallible> + Clone + Send + Sync + 'static {
|
||||
let pool = pool.clone();
|
||||
warp::any().map(move || pool.clone())
|
||||
}
|
||||
|
||||
pub fn with_templates(
|
||||
templates: &Templates,
|
||||
) -> impl Filter<Extract = (Templates,), Error = Infallible> + Clone + Send + Sync + 'static {
|
||||
|
||||
@@ -14,12 +14,12 @@
|
||||
|
||||
use headers::SetCookie;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{Executor, PgPool, Postgres};
|
||||
use sqlx::{pool::PoolConnection, Executor, PgPool, Postgres};
|
||||
use warp::{filters::BoxedFilter, Filter, Rejection, Reply};
|
||||
|
||||
use super::{
|
||||
cookies::{encrypted, maybe_encrypted, save_encrypted, WithTypedHeader},
|
||||
with_pool,
|
||||
database::with_connection,
|
||||
};
|
||||
use crate::{
|
||||
config::CookiesConfig,
|
||||
@@ -53,15 +53,17 @@ pub fn with_optional_session(
|
||||
) -> impl Filter<Extract = (Option<SessionInfo>,), Error = Rejection> + Clone + Send + Sync + 'static
|
||||
{
|
||||
maybe_encrypted("session", cookies_config)
|
||||
.and(with_pool(pool))
|
||||
.and_then(|maybe_session: Option<Session>, pool: PgPool| async move {
|
||||
let maybe_session_info = if let Some(session) = maybe_session {
|
||||
session.load_session_info(&pool).await.ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok::<_, Rejection>(maybe_session_info)
|
||||
})
|
||||
.and(with_connection(pool))
|
||||
.and_then(
|
||||
|maybe_session: Option<Session>, mut conn: PoolConnection<Postgres>| async move {
|
||||
let maybe_session_info = if let Some(session) = maybe_session {
|
||||
session.load_session_info(&mut conn).await.ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok::<_, Rejection>(maybe_session_info)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub fn with_session(
|
||||
@@ -69,11 +71,13 @@ pub fn with_session(
|
||||
cookies_config: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (SessionInfo,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
encrypted("session", cookies_config)
|
||||
.and(with_pool(pool))
|
||||
.and_then(|session: Session, pool: PgPool| async move {
|
||||
let session_info = session.load_session_info(&pool).await.wrap_error()?;
|
||||
Ok::<_, Rejection>(session_info)
|
||||
})
|
||||
.and(with_connection(pool))
|
||||
.and_then(
|
||||
|session: Session, mut conn: PoolConnection<Postgres>| async move {
|
||||
let session_info = session.load_session_info(&mut conn).await.wrap_error()?;
|
||||
Ok::<_, Rejection>(session_info)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub fn save_session<R: Reply, F>(
|
||||
|
||||
@@ -14,25 +14,25 @@
|
||||
|
||||
use hyper::header::CONTENT_TYPE;
|
||||
use mime::TEXT_PLAIN;
|
||||
use sqlx::PgPool;
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||
use tracing::{info_span, Instrument};
|
||||
use warp::{reply::with_header, Filter, Rejection, Reply};
|
||||
|
||||
use crate::{errors::WrapError, filters::with_pool};
|
||||
use crate::{errors::WrapError, filters::database::with_connection};
|
||||
|
||||
pub fn filter(
|
||||
pool: &PgPool,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
warp::get()
|
||||
.and(warp::path("health"))
|
||||
.and(with_pool(pool))
|
||||
.and(with_connection(pool))
|
||||
.and_then(get)
|
||||
}
|
||||
|
||||
async fn get(pool: PgPool) -> Result<impl Reply, Rejection> {
|
||||
async fn get(mut conn: PoolConnection<Postgres>) -> Result<impl Reply, Rejection> {
|
||||
sqlx::query("SELECT $1")
|
||||
.bind(1_i64)
|
||||
.execute(&pool)
|
||||
.execute(&mut conn)
|
||||
.instrument(info_span!("DB health"))
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
@@ -32,7 +32,7 @@ use oauth2_types::{
|
||||
},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
use sqlx::{PgPool, Postgres, Transaction};
|
||||
use url::Url;
|
||||
use warp::{
|
||||
redirect::see_other,
|
||||
@@ -44,8 +44,9 @@ use crate::{
|
||||
config::{CookiesConfig, OAuth2ClientConfig, OAuth2Config},
|
||||
errors::WrapError,
|
||||
filters::{
|
||||
database::with_transaction,
|
||||
session::{with_optional_session, with_session},
|
||||
with_pool, with_templates,
|
||||
with_templates,
|
||||
},
|
||||
handlers::views::LoginRequest,
|
||||
storage::{
|
||||
@@ -164,7 +165,7 @@ pub fn filter(
|
||||
.map(move || clients.clone())
|
||||
.and(warp::query())
|
||||
.and(with_optional_session(pool, cookies_config))
|
||||
.and(with_pool(pool))
|
||||
.and(with_transaction(pool))
|
||||
.and(with_templates(templates))
|
||||
.and_then(get);
|
||||
|
||||
@@ -172,7 +173,7 @@ pub fn filter(
|
||||
.and(warp::path!("oauth2" / "authorize" / "step"))
|
||||
.and(warp::query().map(|s: StepRequest| s.id))
|
||||
.and(with_session(pool, cookies_config))
|
||||
.and(with_pool(pool))
|
||||
.and(with_transaction(pool))
|
||||
.and(with_templates(templates))
|
||||
.and_then(step);
|
||||
|
||||
@@ -183,7 +184,7 @@ async fn get(
|
||||
clients: Vec<OAuth2ClientConfig>,
|
||||
params: Params,
|
||||
maybe_session: Option<SessionInfo>,
|
||||
pool: PgPool,
|
||||
mut txn: Transaction<'_, Postgres>,
|
||||
templates: Templates,
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
// First, find out what client it is
|
||||
@@ -198,8 +199,6 @@ async fn get(
|
||||
.resolve_redirect_uri(¶ms.auth.redirect_uri)
|
||||
.wrap_error()?;
|
||||
|
||||
// Start a DB transaction
|
||||
let mut txn = pool.begin().await.wrap_error()?;
|
||||
let maybe_session_id = maybe_session.as_ref().map(SessionInfo::key);
|
||||
|
||||
let scope: String = {
|
||||
@@ -237,14 +236,15 @@ async fn get(
|
||||
.wrap_error()?;
|
||||
};
|
||||
|
||||
// Do we have a user in this session, with a last authentication time that
|
||||
// matches the requirement?
|
||||
// Do we already have a user session for this oauth2 session?
|
||||
let user_session = oauth2_session.fetch_session(&mut txn).await.wrap_error()?;
|
||||
txn.commit().await.wrap_error()?;
|
||||
|
||||
if let Some(user_session) = user_session {
|
||||
step(oauth2_session.id, user_session, pool, templates).await
|
||||
step(oauth2_session.id, user_session, txn, templates).await
|
||||
} else {
|
||||
// If not, redirect the user to the login page
|
||||
txn.commit().await.wrap_error()?;
|
||||
|
||||
let next = StepRequest::new(oauth2_session.id)
|
||||
.build_uri()
|
||||
.wrap_error()?
|
||||
@@ -280,12 +280,9 @@ impl StepRequest {
|
||||
async fn step(
|
||||
oauth2_session_id: i64,
|
||||
user_session: SessionInfo,
|
||||
pool: PgPool,
|
||||
mut txn: Transaction<'_, Postgres>,
|
||||
templates: Templates,
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
// Start a DB transaction
|
||||
let mut txn = pool.begin().await.wrap_error()?;
|
||||
|
||||
let mut oauth2_session = get_session_by_id(&mut txn, oauth2_session_id)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
@@ -16,7 +16,7 @@ use std::convert::TryFrom;
|
||||
|
||||
use hyper::http::uri::{Parts, PathAndQuery, Uri};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||
use warp::{reply::html, wrap_fn, Filter, Rejection, Reply};
|
||||
|
||||
use crate::{
|
||||
@@ -24,8 +24,9 @@ use crate::{
|
||||
errors::WrapError,
|
||||
filters::{
|
||||
csrf::{protected_form, save_csrf_token, updated_csrf_token},
|
||||
session::save_session,
|
||||
with_pool, with_templates, CsrfToken,
|
||||
database::with_connection,
|
||||
session::{save_session, with_optional_session},
|
||||
with_templates, CsrfToken,
|
||||
},
|
||||
storage::{login, SessionInfo},
|
||||
templates::{TemplateContext, Templates},
|
||||
@@ -51,6 +52,22 @@ impl LoginRequest {
|
||||
})?;
|
||||
Ok(uri)
|
||||
}
|
||||
|
||||
fn redirect(self) -> Result<impl Reply, Rejection> {
|
||||
let uri: Uri = Uri::from_parts({
|
||||
let mut parts = Parts::default();
|
||||
parts.path_and_query = Some(
|
||||
self.next
|
||||
.map(warp::http::uri::PathAndQuery::try_from)
|
||||
.transpose()
|
||||
.wrap_error()?
|
||||
.unwrap_or_else(|| PathAndQuery::from_static("/")),
|
||||
);
|
||||
parts
|
||||
})
|
||||
.wrap_error()?;
|
||||
Ok(warp::redirect::see_other(uri))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -68,12 +85,14 @@ pub(super) fn filter(
|
||||
let get = warp::get()
|
||||
.and(with_templates(templates))
|
||||
.and(updated_csrf_token(cookies_config, csrf_config))
|
||||
.and(warp::query())
|
||||
.and(with_optional_session(pool, cookies_config))
|
||||
.and_then(get)
|
||||
.untuple_one()
|
||||
.with(wrap_fn(save_csrf_token(cookies_config)));
|
||||
|
||||
let post = warp::post()
|
||||
.and(with_pool(pool))
|
||||
.and(with_connection(pool))
|
||||
.and(protected_form(cookies_config))
|
||||
.and(warp::query())
|
||||
.and_then(post)
|
||||
@@ -86,36 +105,26 @@ pub(super) fn filter(
|
||||
async fn get(
|
||||
templates: Templates,
|
||||
csrf_token: CsrfToken,
|
||||
) -> Result<(CsrfToken, impl Reply), Rejection> {
|
||||
let ctx = ().with_csrf(&csrf_token);
|
||||
|
||||
// TODO: check if there is an existing session
|
||||
let content = templates.render_login(&ctx)?;
|
||||
Ok((csrf_token, html(content)))
|
||||
query: LoginRequest,
|
||||
maybe_session: Option<SessionInfo>,
|
||||
) -> Result<(CsrfToken, Box<dyn Reply>), Rejection> {
|
||||
if maybe_session.is_some() {
|
||||
Ok((csrf_token, Box::new(query.redirect()?)))
|
||||
} else {
|
||||
let ctx = ().with_csrf(&csrf_token);
|
||||
let content = templates.render_login(&ctx)?;
|
||||
Ok((csrf_token, Box::new(html(content))))
|
||||
}
|
||||
}
|
||||
|
||||
async fn post(
|
||||
db: PgPool,
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
form: LoginForm,
|
||||
query: LoginRequest,
|
||||
) -> Result<(SessionInfo, impl Reply), Rejection> {
|
||||
let session_info = login(&db, &form.username, &form.password)
|
||||
let session_info = login(&mut conn, &form.username, &form.password)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
let uri: Uri = Uri::from_parts({
|
||||
let mut parts = Parts::default();
|
||||
parts.path_and_query = Some(
|
||||
query
|
||||
.next
|
||||
.map(warp::http::uri::PathAndQuery::try_from)
|
||||
.transpose()
|
||||
.wrap_error()?
|
||||
.unwrap_or_else(|| PathAndQuery::from_static("/")),
|
||||
);
|
||||
parts
|
||||
})
|
||||
.wrap_error()?;
|
||||
|
||||
Ok((session_info, warp::redirect(uri)))
|
||||
Ok((session_info, query.redirect()?))
|
||||
}
|
||||
|
||||
@@ -12,13 +12,13 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use sqlx::PgPool;
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||
use warp::{hyper::Uri, Filter, Rejection, Reply};
|
||||
|
||||
use crate::{
|
||||
config::CookiesConfig,
|
||||
errors::WrapError,
|
||||
filters::{csrf::protected_form, session::with_session, with_pool},
|
||||
filters::{csrf::protected_form, database::with_connection, session::with_session},
|
||||
storage::SessionInfo,
|
||||
};
|
||||
|
||||
@@ -29,12 +29,16 @@ pub(super) fn filter(
|
||||
warp::post()
|
||||
.and(warp::path("logout"))
|
||||
.and(with_session(pool, cookies_config))
|
||||
.and(with_pool(pool))
|
||||
.and(with_connection(pool))
|
||||
.and(protected_form(cookies_config))
|
||||
.and_then(post)
|
||||
}
|
||||
|
||||
async fn post(session: SessionInfo, pool: PgPool, _form: ()) -> Result<impl Reply, Rejection> {
|
||||
session.end(&pool).await.wrap_error()?;
|
||||
async fn post(
|
||||
session: SessionInfo,
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
_form: (),
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
session.end(&mut conn).await.wrap_error()?;
|
||||
Ok::<_, Rejection>(warp::redirect(Uri::from_static("/login")))
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
use serde::Deserialize;
|
||||
use sqlx::PgPool;
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||
use warp::{hyper::Uri, reply::html, wrap_fn, Filter, Rejection, Reply};
|
||||
|
||||
use crate::{
|
||||
@@ -21,8 +21,9 @@ use crate::{
|
||||
errors::WrapError,
|
||||
filters::{
|
||||
csrf::{protected_form, save_csrf_token, updated_csrf_token},
|
||||
database::with_connection,
|
||||
session::with_session,
|
||||
with_pool, with_templates, CsrfToken,
|
||||
with_templates, CsrfToken,
|
||||
},
|
||||
storage::SessionInfo,
|
||||
templates::{TemplateContext, Templates},
|
||||
@@ -49,7 +50,7 @@ pub(super) fn filter(
|
||||
|
||||
let post = warp::post()
|
||||
.and(with_session(pool, cookies_config))
|
||||
.and(with_pool(pool))
|
||||
.and(with_connection(pool))
|
||||
.and(protected_form(cookies_config))
|
||||
.and_then(post);
|
||||
|
||||
@@ -69,10 +70,13 @@ async fn get(
|
||||
|
||||
async fn post(
|
||||
session: SessionInfo,
|
||||
pool: PgPool,
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
form: ReauthForm,
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
let _session = session.reauth(&pool, &form.password).await.wrap_error()?;
|
||||
let _session = session
|
||||
.reauth(&mut conn, &form.password)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
Ok(warp::redirect(Uri::from_static("/")))
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ use chrono::{DateTime, Utc};
|
||||
use password_hash::{PasswordHash, PasswordHasher, SaltString};
|
||||
use rand::rngs::OsRng;
|
||||
use serde::Serialize;
|
||||
use sqlx::{Executor, FromRow, PgPool, Postgres, Transaction};
|
||||
use sqlx::{Acquire, Executor, FromRow, Postgres, Transaction};
|
||||
use tracing::{info_span, Instrument};
|
||||
|
||||
#[derive(Serialize, Debug, Clone, FromRow)]
|
||||
@@ -44,8 +44,12 @@ impl SessionInfo {
|
||||
self.id
|
||||
}
|
||||
|
||||
pub async fn reauth(mut self, pool: &PgPool, password: &str) -> anyhow::Result<Self> {
|
||||
let mut txn = pool.begin().await?;
|
||||
pub async fn reauth(
|
||||
mut self,
|
||||
conn: impl Acquire<'_, Database = Postgres>,
|
||||
password: &str,
|
||||
) -> anyhow::Result<Self> {
|
||||
let mut txn = conn.begin().await?;
|
||||
self.last_authd_at = Some(authenticate_session(&mut txn, self.id, password).await?);
|
||||
txn.commit().await?;
|
||||
Ok(self)
|
||||
@@ -61,8 +65,12 @@ impl SessionInfo {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn login(pool: &PgPool, username: &str, password: &str) -> anyhow::Result<SessionInfo> {
|
||||
let mut txn = pool.begin().await?;
|
||||
pub async fn login(
|
||||
conn: impl Acquire<'_, Database = Postgres>,
|
||||
username: &str,
|
||||
password: &str,
|
||||
) -> anyhow::Result<SessionInfo> {
|
||||
let mut txn = conn.begin().await?;
|
||||
let user = lookup_user_by_username(&mut txn, username).await?;
|
||||
let mut session = start_session(&mut txn, user).await?;
|
||||
session.last_authd_at = Some(authenticate_session(&mut txn, session.id, password).await?);
|
||||
|
||||
Reference in New Issue
Block a user