Acquire DB conns and txns on filter level

This avoids having the pool everywhere and instead have connections and
transactions as parameters
This commit is contained in:
Quentin Gliech
2021-08-13 09:38:41 +02:00
parent 4eb1b5d4f8
commit da13e24789
9 changed files with 159 additions and 86 deletions
@@ -0,0 +1,54 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::convert::Infallible;
use sqlx::{pool::PoolConnection, PgPool, Postgres, Transaction};
use warp::{Filter, Rejection};
use crate::errors::WrapError;
fn with_pool(
pool: &PgPool,
) -> impl Filter<Extract = (PgPool,), Error = Infallible> + Clone + Send + Sync + 'static {
let pool = pool.clone();
warp::any().map(move || pool.clone())
}
pub fn with_connection(
pool: &PgPool,
) -> impl Filter<Extract = (PoolConnection<Postgres>,), Error = Rejection> + Clone + Send + Sync + 'static
{
with_pool(pool).and_then(acquire_connection)
}
async fn acquire_connection(pool: PgPool) -> Result<PoolConnection<Postgres>, Rejection> {
let conn = pool.acquire().await.wrap_error()?;
Ok(conn)
}
pub fn with_transaction(
pool: &PgPool,
) -> impl Filter<Extract = (Transaction<'static, Postgres>,), Error = Rejection>
+ Clone
+ Send
+ Sync
+ 'static {
with_pool(pool).and_then(acquire_transaction)
}
async fn acquire_transaction(pool: PgPool) -> Result<Transaction<'static, Postgres>, Rejection> {
let txn = pool.begin().await.wrap_error()?;
Ok(txn)
}
@@ -15,23 +15,16 @@
pub mod csrf;
// mod errors;
pub mod cookies;
pub mod database;
pub mod session;
use std::convert::Infallible;
use sqlx::PgPool;
use warp::Filter;
pub use self::csrf::CsrfToken;
use crate::templates::Templates;
pub fn with_pool(
pool: &PgPool,
) -> impl Filter<Extract = (PgPool,), Error = Infallible> + Clone + Send + Sync + 'static {
let pool = pool.clone();
warp::any().map(move || pool.clone())
}
pub fn with_templates(
templates: &Templates,
) -> impl Filter<Extract = (Templates,), Error = Infallible> + Clone + Send + Sync + 'static {
@@ -14,12 +14,12 @@
use headers::SetCookie;
use serde::{Deserialize, Serialize};
use sqlx::{Executor, PgPool, Postgres};
use sqlx::{pool::PoolConnection, Executor, PgPool, Postgres};
use warp::{filters::BoxedFilter, Filter, Rejection, Reply};
use super::{
cookies::{encrypted, maybe_encrypted, save_encrypted, WithTypedHeader},
with_pool,
database::with_connection,
};
use crate::{
config::CookiesConfig,
@@ -53,15 +53,17 @@ pub fn with_optional_session(
) -> impl Filter<Extract = (Option<SessionInfo>,), Error = Rejection> + Clone + Send + Sync + 'static
{
maybe_encrypted("session", cookies_config)
.and(with_pool(pool))
.and_then(|maybe_session: Option<Session>, pool: PgPool| async move {
let maybe_session_info = if let Some(session) = maybe_session {
session.load_session_info(&pool).await.ok()
} else {
None
};
Ok::<_, Rejection>(maybe_session_info)
})
.and(with_connection(pool))
.and_then(
|maybe_session: Option<Session>, mut conn: PoolConnection<Postgres>| async move {
let maybe_session_info = if let Some(session) = maybe_session {
session.load_session_info(&mut conn).await.ok()
} else {
None
};
Ok::<_, Rejection>(maybe_session_info)
},
)
}
pub fn with_session(
@@ -69,11 +71,13 @@ pub fn with_session(
cookies_config: &CookiesConfig,
) -> impl Filter<Extract = (SessionInfo,), Error = Rejection> + Clone + Send + Sync + 'static {
encrypted("session", cookies_config)
.and(with_pool(pool))
.and_then(|session: Session, pool: PgPool| async move {
let session_info = session.load_session_info(&pool).await.wrap_error()?;
Ok::<_, Rejection>(session_info)
})
.and(with_connection(pool))
.and_then(
|session: Session, mut conn: PoolConnection<Postgres>| async move {
let session_info = session.load_session_info(&mut conn).await.wrap_error()?;
Ok::<_, Rejection>(session_info)
},
)
}
pub fn save_session<R: Reply, F>(
@@ -14,25 +14,25 @@
use hyper::header::CONTENT_TYPE;
use mime::TEXT_PLAIN;
use sqlx::PgPool;
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use tracing::{info_span, Instrument};
use warp::{reply::with_header, Filter, Rejection, Reply};
use crate::{errors::WrapError, filters::with_pool};
use crate::{errors::WrapError, filters::database::with_connection};
pub fn filter(
pool: &PgPool,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
warp::get()
.and(warp::path("health"))
.and(with_pool(pool))
.and(with_connection(pool))
.and_then(get)
}
async fn get(pool: PgPool) -> Result<impl Reply, Rejection> {
async fn get(mut conn: PoolConnection<Postgres>) -> Result<impl Reply, Rejection> {
sqlx::query("SELECT $1")
.bind(1_i64)
.execute(&pool)
.execute(&mut conn)
.instrument(info_span!("DB health"))
.await
.wrap_error()?;
@@ -32,7 +32,7 @@ use oauth2_types::{
},
};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use sqlx::{PgPool, Postgres, Transaction};
use url::Url;
use warp::{
redirect::see_other,
@@ -44,8 +44,9 @@ use crate::{
config::{CookiesConfig, OAuth2ClientConfig, OAuth2Config},
errors::WrapError,
filters::{
database::with_transaction,
session::{with_optional_session, with_session},
with_pool, with_templates,
with_templates,
},
handlers::views::LoginRequest,
storage::{
@@ -164,7 +165,7 @@ pub fn filter(
.map(move || clients.clone())
.and(warp::query())
.and(with_optional_session(pool, cookies_config))
.and(with_pool(pool))
.and(with_transaction(pool))
.and(with_templates(templates))
.and_then(get);
@@ -172,7 +173,7 @@ pub fn filter(
.and(warp::path!("oauth2" / "authorize" / "step"))
.and(warp::query().map(|s: StepRequest| s.id))
.and(with_session(pool, cookies_config))
.and(with_pool(pool))
.and(with_transaction(pool))
.and(with_templates(templates))
.and_then(step);
@@ -183,7 +184,7 @@ async fn get(
clients: Vec<OAuth2ClientConfig>,
params: Params,
maybe_session: Option<SessionInfo>,
pool: PgPool,
mut txn: Transaction<'_, Postgres>,
templates: Templates,
) -> Result<Box<dyn Reply>, Rejection> {
// First, find out what client it is
@@ -198,8 +199,6 @@ async fn get(
.resolve_redirect_uri(&params.auth.redirect_uri)
.wrap_error()?;
// Start a DB transaction
let mut txn = pool.begin().await.wrap_error()?;
let maybe_session_id = maybe_session.as_ref().map(SessionInfo::key);
let scope: String = {
@@ -237,14 +236,15 @@ async fn get(
.wrap_error()?;
};
// Do we have a user in this session, with a last authentication time that
// matches the requirement?
// Do we already have a user session for this oauth2 session?
let user_session = oauth2_session.fetch_session(&mut txn).await.wrap_error()?;
txn.commit().await.wrap_error()?;
if let Some(user_session) = user_session {
step(oauth2_session.id, user_session, pool, templates).await
step(oauth2_session.id, user_session, txn, templates).await
} else {
// If not, redirect the user to the login page
txn.commit().await.wrap_error()?;
let next = StepRequest::new(oauth2_session.id)
.build_uri()
.wrap_error()?
@@ -280,12 +280,9 @@ impl StepRequest {
async fn step(
oauth2_session_id: i64,
user_session: SessionInfo,
pool: PgPool,
mut txn: Transaction<'_, Postgres>,
templates: Templates,
) -> Result<Box<dyn Reply>, Rejection> {
// Start a DB transaction
let mut txn = pool.begin().await.wrap_error()?;
let mut oauth2_session = get_session_by_id(&mut txn, oauth2_session_id)
.await
.wrap_error()?;
@@ -16,7 +16,7 @@ use std::convert::TryFrom;
use hyper::http::uri::{Parts, PathAndQuery, Uri};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use warp::{reply::html, wrap_fn, Filter, Rejection, Reply};
use crate::{
@@ -24,8 +24,9 @@ use crate::{
errors::WrapError,
filters::{
csrf::{protected_form, save_csrf_token, updated_csrf_token},
session::save_session,
with_pool, with_templates, CsrfToken,
database::with_connection,
session::{save_session, with_optional_session},
with_templates, CsrfToken,
},
storage::{login, SessionInfo},
templates::{TemplateContext, Templates},
@@ -51,6 +52,22 @@ impl LoginRequest {
})?;
Ok(uri)
}
fn redirect(self) -> Result<impl Reply, Rejection> {
let uri: Uri = Uri::from_parts({
let mut parts = Parts::default();
parts.path_and_query = Some(
self.next
.map(warp::http::uri::PathAndQuery::try_from)
.transpose()
.wrap_error()?
.unwrap_or_else(|| PathAndQuery::from_static("/")),
);
parts
})
.wrap_error()?;
Ok(warp::redirect::see_other(uri))
}
}
#[derive(Deserialize)]
@@ -68,12 +85,14 @@ pub(super) fn filter(
let get = warp::get()
.and(with_templates(templates))
.and(updated_csrf_token(cookies_config, csrf_config))
.and(warp::query())
.and(with_optional_session(pool, cookies_config))
.and_then(get)
.untuple_one()
.with(wrap_fn(save_csrf_token(cookies_config)));
let post = warp::post()
.and(with_pool(pool))
.and(with_connection(pool))
.and(protected_form(cookies_config))
.and(warp::query())
.and_then(post)
@@ -86,36 +105,26 @@ pub(super) fn filter(
async fn get(
templates: Templates,
csrf_token: CsrfToken,
) -> Result<(CsrfToken, impl Reply), Rejection> {
let ctx = ().with_csrf(&csrf_token);
// TODO: check if there is an existing session
let content = templates.render_login(&ctx)?;
Ok((csrf_token, html(content)))
query: LoginRequest,
maybe_session: Option<SessionInfo>,
) -> Result<(CsrfToken, Box<dyn Reply>), Rejection> {
if maybe_session.is_some() {
Ok((csrf_token, Box::new(query.redirect()?)))
} else {
let ctx = ().with_csrf(&csrf_token);
let content = templates.render_login(&ctx)?;
Ok((csrf_token, Box::new(html(content))))
}
}
async fn post(
db: PgPool,
mut conn: PoolConnection<Postgres>,
form: LoginForm,
query: LoginRequest,
) -> Result<(SessionInfo, impl Reply), Rejection> {
let session_info = login(&db, &form.username, &form.password)
let session_info = login(&mut conn, &form.username, &form.password)
.await
.wrap_error()?;
let uri: Uri = Uri::from_parts({
let mut parts = Parts::default();
parts.path_and_query = Some(
query
.next
.map(warp::http::uri::PathAndQuery::try_from)
.transpose()
.wrap_error()?
.unwrap_or_else(|| PathAndQuery::from_static("/")),
);
parts
})
.wrap_error()?;
Ok((session_info, warp::redirect(uri)))
Ok((session_info, query.redirect()?))
}
@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use sqlx::PgPool;
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use warp::{hyper::Uri, Filter, Rejection, Reply};
use crate::{
config::CookiesConfig,
errors::WrapError,
filters::{csrf::protected_form, session::with_session, with_pool},
filters::{csrf::protected_form, database::with_connection, session::with_session},
storage::SessionInfo,
};
@@ -29,12 +29,16 @@ pub(super) fn filter(
warp::post()
.and(warp::path("logout"))
.and(with_session(pool, cookies_config))
.and(with_pool(pool))
.and(with_connection(pool))
.and(protected_form(cookies_config))
.and_then(post)
}
async fn post(session: SessionInfo, pool: PgPool, _form: ()) -> Result<impl Reply, Rejection> {
session.end(&pool).await.wrap_error()?;
async fn post(
session: SessionInfo,
mut conn: PoolConnection<Postgres>,
_form: (),
) -> Result<impl Reply, Rejection> {
session.end(&mut conn).await.wrap_error()?;
Ok::<_, Rejection>(warp::redirect(Uri::from_static("/login")))
}
@@ -13,7 +13,7 @@
// limitations under the License.
use serde::Deserialize;
use sqlx::PgPool;
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use warp::{hyper::Uri, reply::html, wrap_fn, Filter, Rejection, Reply};
use crate::{
@@ -21,8 +21,9 @@ use crate::{
errors::WrapError,
filters::{
csrf::{protected_form, save_csrf_token, updated_csrf_token},
database::with_connection,
session::with_session,
with_pool, with_templates, CsrfToken,
with_templates, CsrfToken,
},
storage::SessionInfo,
templates::{TemplateContext, Templates},
@@ -49,7 +50,7 @@ pub(super) fn filter(
let post = warp::post()
.and(with_session(pool, cookies_config))
.and(with_pool(pool))
.and(with_connection(pool))
.and(protected_form(cookies_config))
.and_then(post);
@@ -69,10 +70,13 @@ async fn get(
async fn post(
session: SessionInfo,
pool: PgPool,
mut conn: PoolConnection<Postgres>,
form: ReauthForm,
) -> Result<impl Reply, Rejection> {
let _session = session.reauth(&pool, &form.password).await.wrap_error()?;
let _session = session
.reauth(&mut conn, &form.password)
.await
.wrap_error()?;
Ok(warp::redirect(Uri::from_static("/")))
}
@@ -20,7 +20,7 @@ use chrono::{DateTime, Utc};
use password_hash::{PasswordHash, PasswordHasher, SaltString};
use rand::rngs::OsRng;
use serde::Serialize;
use sqlx::{Executor, FromRow, PgPool, Postgres, Transaction};
use sqlx::{Acquire, Executor, FromRow, Postgres, Transaction};
use tracing::{info_span, Instrument};
#[derive(Serialize, Debug, Clone, FromRow)]
@@ -44,8 +44,12 @@ impl SessionInfo {
self.id
}
pub async fn reauth(mut self, pool: &PgPool, password: &str) -> anyhow::Result<Self> {
let mut txn = pool.begin().await?;
pub async fn reauth(
mut self,
conn: impl Acquire<'_, Database = Postgres>,
password: &str,
) -> anyhow::Result<Self> {
let mut txn = conn.begin().await?;
self.last_authd_at = Some(authenticate_session(&mut txn, self.id, password).await?);
txn.commit().await?;
Ok(self)
@@ -61,8 +65,12 @@ impl SessionInfo {
}
}
pub async fn login(pool: &PgPool, username: &str, password: &str) -> anyhow::Result<SessionInfo> {
let mut txn = pool.begin().await?;
pub async fn login(
conn: impl Acquire<'_, Database = Postgres>,
username: &str,
password: &str,
) -> anyhow::Result<SessionInfo> {
let mut txn = conn.begin().await?;
let user = lookup_user_by_username(&mut txn, username).await?;
let mut session = start_session(&mut txn, user).await?;
session.last_authd_at = Some(authenticate_session(&mut txn, session.id, password).await?);