diff --git a/matrix-authentication-service/src/cli/server.rs b/matrix-authentication-service/src/cli/server.rs index c98cbe9eb..79eae172f 100644 --- a/matrix-authentication-service/src/cli/server.rs +++ b/matrix-authentication-service/src/cli/server.rs @@ -19,6 +19,7 @@ use clap::Clap; use super::RootCommand; use crate::config::RootConfig; +use crate::templates::Templates; #[derive(Clap, Debug, Default)] pub(super) struct ServerCommand; @@ -31,7 +32,7 @@ impl ServerCommand { let pool = config.database.connect().await?; // Load and compile the templates - let templates = crate::templates::load().context("could not load templates")?; + let templates = Templates::load().context("could not load templates")?; // Start the server let address: SocketAddr = config.http.address.parse()?; diff --git a/matrix-authentication-service/src/config/csrf.rs b/matrix-authentication-service/src/config/csrf.rs index b6b0840b9..52c6efa5d 100644 --- a/matrix-authentication-service/src/config/csrf.rs +++ b/matrix-authentication-service/src/config/csrf.rs @@ -55,10 +55,10 @@ pub struct CsrfConfig { } impl CsrfConfig { - pub fn into_extract_filter(self) -> BoxedFilter<(CsrfToken,)> { + pub fn to_extract_filter(&self) -> BoxedFilter<(CsrfToken,)> { let ttl = self.ttl; // TODO: we should probably not leak here - let cookie_name = Box::leak(Box::new(self.cookie_name)); + let cookie_name = Box::leak(Box::new(self.cookie_name.clone())); extract_or_generate(self.key, cookie_name, ttl) } } diff --git a/matrix-authentication-service/src/filters/mod.rs b/matrix-authentication-service/src/filters/mod.rs index 761d80efc..b71613853 100644 --- a/matrix-authentication-service/src/filters/mod.rs +++ b/matrix-authentication-service/src/filters/mod.rs @@ -15,4 +15,17 @@ pub mod csrf; // mod errors; -pub use csrf::UnencryptedToken as CsrfToken; +use sqlx::PgPool; +use warp::{filters::BoxedFilter, Filter}; + +use crate::templates::Templates; + +pub use self::csrf::UnencryptedToken as CsrfToken; + +pub fn with_pool(pool: PgPool) -> BoxedFilter<(PgPool,)> { + warp::any().map(move || pool.clone()).boxed() +} + +pub fn with_templates(templates: Templates) -> BoxedFilter<(Templates,)> { + warp::any().map(move || templates.clone()).boxed() +} diff --git a/matrix-authentication-service/src/handlers/health.rs b/matrix-authentication-service/src/handlers/health.rs index 3956d9a34..785705579 100644 --- a/matrix-authentication-service/src/handlers/health.rs +++ b/matrix-authentication-service/src/handlers/health.rs @@ -14,11 +14,19 @@ use sqlx::PgPool; use tracing::{info_span, Instrument}; -use warp::{Rejection, Reply}; +use warp::{filters::BoxedFilter, Filter, Rejection, Reply}; -use crate::errors::WrapError; +use crate::{errors::WrapError, filters::with_pool}; -pub async fn get(pool: PgPool) -> Result { +pub fn filter(pool: PgPool) -> BoxedFilter<(impl Reply,)> { + warp::get() + .and(warp::path("health")) + .and(with_pool(pool)) + .and_then(get) + .boxed() +} + +async fn get(pool: PgPool) -> Result { sqlx::query("SELECT $1") .bind(1_i64) .execute(&pool) diff --git a/matrix-authentication-service/src/handlers/mod.rs b/matrix-authentication-service/src/handlers/mod.rs index aae31d0eb..31baea2a8 100644 --- a/matrix-authentication-service/src/handlers/mod.rs +++ b/matrix-authentication-service/src/handlers/mod.rs @@ -12,94 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{convert::Infallible, sync::Arc}; - use sqlx::PgPool; -use tera::Tera; -use warp::{filters::BoxedFilter, wrap_fn, Filter, Rejection, Reply}; +use warp::{filters::BoxedFilter, Filter}; -use crate::{config::RootConfig, filters::csrf::with_csrf}; +use crate::{config::RootConfig, templates::Templates}; mod health; mod oauth2; mod views; -async fn display_error(err: Rejection) -> Result { - let ret = format!("{:?}", err); - Ok(ret) -} +use self::{health::filter as health, oauth2::filter as oauth2, views::filter as views}; pub fn root( pool: PgPool, - templates: Tera, + templates: Templates, config: &RootConfig, ) -> BoxedFilter<(impl warp::Reply,)> { - let templates = Arc::new(templates); - let with_csrf_token = config.csrf.clone().into_extract_filter(); - let with_pool = warp::any().map(move || pool.clone()); - let with_templates = warp::any().map(move || templates.clone()); - - // TODO: this is ugly and leaks - let csrf_cookie_name = Box::leak(Box::new(config.csrf.cookie_name.clone())); - - let cors = warp::cors().allow_any_origin(); - - let health = warp::path("health") - .and(warp::get()) - .and(with_pool.clone()) - .and_then(self::health::get) - .boxed(); - - let metadata = warp::path!(".well-known" / "openid-configuration") - .and(warp::get()) - .and(self::oauth2::discovery::get(&config.oauth2)) - .with(cors); - - let index = warp::path::end() - .and(warp::get()) - .and(with_templates.clone()) - .and(with_csrf_token.clone()) - .and(with_pool.clone()) - .and_then(self::views::index::get) - .untuple_one() - .with(wrap_fn(with_csrf(config.csrf.key, csrf_cookie_name))); - - let login = warp::path("login") - .and(warp::get()) - .and(with_templates) - .and(with_csrf_token) - .and(with_pool) - .and_then(self::views::login::get) - .untuple_one() - .with(wrap_fn(with_csrf(config.csrf.key, csrf_cookie_name))); - - health.or(index).or(login).or(metadata).boxed() - - // app.at("/").nest({ - // let mut views = tide::with_state(state.clone()); - // views.with(state.session_middleware()); - // views.with(state.csrf_middleware()); - // views.with(crate::middlewares::errors); - - // views.at("/").get(self::views::index::get); - - // views - // .at("/login") - // .get(self::views::login::get) - // .post(self::views::login::post); - - // views - // .at("/reauth") - // .get(self::views::reauth::get) - // .post(self::views::reauth::post); - - // views.at("/logout").post(self::views::logout::post); - - // views - // .at("oauth2/authorize") - // .with(BrowserErrorHandler) - // .get(self::oauth2::authorization::get); - - // views - // }); + health(pool.clone()) + .or(oauth2(&config.oauth2)) + .or(views(pool, templates, &config.csrf)) + .boxed() } diff --git a/matrix-authentication-service/src/handlers/oauth2/discovery.rs b/matrix-authentication-service/src/handlers/oauth2/discovery.rs index 2f17f3529..88fa3ecdd 100644 --- a/matrix-authentication-service/src/handlers/oauth2/discovery.rs +++ b/matrix-authentication-service/src/handlers/oauth2/discovery.rs @@ -17,7 +17,7 @@ use warp::{filters::BoxedFilter, Filter, Reply}; use crate::config::OAuth2Config; -pub fn get(config: &OAuth2Config) -> BoxedFilter<(impl Reply,)> { +pub(super) fn filter(config: &OAuth2Config) -> BoxedFilter<(impl Reply,)> { let base = config.issuer.clone(); let metadata = Metadata { authorization_endpoint: base.join("oauth2/authorize").ok(), @@ -32,7 +32,11 @@ pub fn get(config: &OAuth2Config) -> BoxedFilter<(impl Reply,)> { code_challenge_methods_supported: None, }; - warp::any() + let cors = warp::cors().allow_any_origin(); + + warp::get() + .and(warp::path!(".well-known" / "openid-configuration")) .map(move || warp::reply::json(&metadata)) + .with(cors) .boxed() } diff --git a/matrix-authentication-service/src/handlers/oauth2/mod.rs b/matrix-authentication-service/src/handlers/oauth2/mod.rs index d9edb1516..15c28cc79 100644 --- a/matrix-authentication-service/src/handlers/oauth2/mod.rs +++ b/matrix-authentication-service/src/handlers/oauth2/mod.rs @@ -12,5 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +use warp::{filters::BoxedFilter, Reply}; + +use crate::config::OAuth2Config; + // pub mod authorization; -pub mod discovery; +mod discovery; + +use self::discovery::filter as discovery; + +pub fn filter(config: &OAuth2Config) -> BoxedFilter<(impl Reply,)> { + discovery(config) +} diff --git a/matrix-authentication-service/src/handlers/views/index.rs b/matrix-authentication-service/src/handlers/views/index.rs index f6f5ecb23..ba6726512 100644 --- a/matrix-authentication-service/src/handlers/views/index.rs +++ b/matrix-authentication-service/src/handlers/views/index.rs @@ -12,22 +12,43 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; - use sqlx::PgPool; -use tera::Tera; -use warp::{reply::with_header, Rejection, Reply}; +use warp::{filters::BoxedFilter, reply::with_header, wrap_fn, Filter, Rejection, Reply}; -use crate::{errors::WrapError, filters::CsrfToken, templates::CommonContext}; +use crate::{ + config::CsrfConfig, + errors::WrapError, + filters::{csrf::with_csrf, with_pool, with_templates, CsrfToken}, + templates::{CommonContext, Templates}, +}; -pub async fn get( - templates: Arc, +pub(super) fn filter( + pool: PgPool, + templates: Templates, + csrf_config: &CsrfConfig, +) -> BoxedFilter<(impl Reply,)> { + // TODO: this is ugly and leaks + let csrf_cookie_name = Box::leak(Box::new(csrf_config.cookie_name.clone())); + + warp::get() + .and(warp::path::end()) + .and(with_templates(templates)) + .and(csrf_config.to_extract_filter()) + .and(with_pool(pool)) + .and_then(get) + .untuple_one() + .with(wrap_fn(with_csrf(csrf_config.key, csrf_cookie_name))) + .boxed() +} + +async fn get( + templates: Templates, csrf_token: CsrfToken, db: PgPool, ) -> Result<(CsrfToken, impl Reply), Rejection> { let ctx = CommonContext::default() .with_csrf_token(&csrf_token) - .with_session(&db) + .load_session(&db) .await .wrap_error()? .finish() diff --git a/matrix-authentication-service/src/handlers/views/login.rs b/matrix-authentication-service/src/handlers/views/login.rs index 5573882f8..c31674003 100644 --- a/matrix-authentication-service/src/handlers/views/login.rs +++ b/matrix-authentication-service/src/handlers/views/login.rs @@ -12,14 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; - use serde::Deserialize; use sqlx::PgPool; -use tera::Tera; -use warp::{reply::with_header, Rejection, Reply}; +use warp::{ + filters::BoxedFilter, hyper::Uri, reply::with_header, wrap_fn, Filter, Rejection, Reply, +}; -use crate::{errors::WrapError, filters::CsrfToken, templates::CommonContext}; +use crate::{ + config::CsrfConfig, + csrf::CsrfForm, + errors::WrapError, + filters::{csrf::with_csrf, with_pool, with_templates, CsrfToken}, + storage::login, + templates::{CommonContext, Templates}, +}; #[derive(Deserialize)] struct LoginForm { @@ -27,14 +33,41 @@ struct LoginForm { password: String, } -pub async fn get( - templates: Arc, +pub(super) fn filter( + pool: PgPool, + templates: Templates, + csrf_config: &CsrfConfig, +) -> BoxedFilter<(impl Reply,)> { + // TODO: this is ugly and leaks + let csrf_cookie_name = Box::leak(Box::new(csrf_config.cookie_name.clone())); + + let get = warp::get() + .and(with_templates(templates)) + .and(csrf_config.to_extract_filter()) + .and(with_pool(pool.clone())) + .and_then(get) + .untuple_one() + .with(wrap_fn(with_csrf(csrf_config.key, csrf_cookie_name))); + + let post = warp::post() + .and(csrf_config.to_extract_filter()) + .and(with_pool(pool)) + .and(warp::body::form()) + .and_then(post) + .untuple_one() + .with(wrap_fn(with_csrf(csrf_config.key, csrf_cookie_name))); + + warp::path("login").and(get.or(post)).boxed() +} + +async fn get( + templates: Templates, csrf_token: CsrfToken, db: PgPool, ) -> Result<(CsrfToken, impl Reply), Rejection> { let ctx = CommonContext::default() .with_csrf_token(&csrf_token) - .with_session(&db) + .load_session(&db) .await .wrap_error()? .finish() @@ -48,20 +81,16 @@ pub async fn get( )) } -/* -pub async fn post(mut req: Request) -> tide::Result { - let form: CsrfForm = req.body_form().await?; - let form = form.verify_csrf(&req)?; - let state = req.state(); +async fn post( + csrf_token: CsrfToken, + db: PgPool, + form: CsrfForm, +) -> Result<(CsrfToken, impl Reply), Rejection> { + let form = form.verify_csrf(&csrf_token).wrap_error()?; - let session_info = state - .storage() - .login(&form.username, &form.password) - .await?; + let _session_info = login(&db, &form.username, &form.password) + .await + .wrap_error()?; - let session = req.session_mut(); - session.insert("current_session", session_info.key())?; - - Ok(Redirect::new("/").into()) + Ok((csrf_token, warp::redirect(Uri::from_static("/")))) } -*/ diff --git a/matrix-authentication-service/src/handlers/views/logout.rs b/matrix-authentication-service/src/handlers/views/logout.rs index 5c2180d9e..d1d186706 100644 --- a/matrix-authentication-service/src/handlers/views/logout.rs +++ b/matrix-authentication-service/src/handlers/views/logout.rs @@ -12,16 +12,28 @@ // See the License for the specific language governing permissions and // limitations under the License. -use tide::{Redirect, Request}; +use warp::{filters::BoxedFilter, hyper::Uri, wrap_fn, Filter, Rejection, Reply}; -use crate::{csrf::CsrfForm, state::State}; +use crate::{ + config::CsrfConfig, + csrf::CsrfForm, + errors::WrapError, + filters::{csrf::with_csrf, CsrfToken}, +}; -pub async fn post(mut req: Request) -> tide::Result { - let form: CsrfForm<()> = req.body_form().await?; - form.verify_csrf(&req)?; +pub(super) fn filter(csrf_config: &CsrfConfig) -> BoxedFilter<(impl Reply,)> { + // TODO: this is ugly and leaks + let csrf_cookie_name = Box::leak(Box::new(csrf_config.cookie_name.clone())); - let session = req.session_mut(); - session.remove("current_session"); - - Ok(Redirect::new("/").into()) + warp::post() + .and(warp::path("logout")) + .and(csrf_config.to_extract_filter()) + .and(warp::body::form()) + .and_then(|token: CsrfToken, form: CsrfForm<()>| async { + form.verify_csrf(&token).wrap_error()?; + Ok::<_, Rejection>((token, warp::redirect(Uri::from_static("/login")))) + }) + .untuple_one() + .with(wrap_fn(with_csrf(csrf_config.key, csrf_cookie_name))) + .boxed() } diff --git a/matrix-authentication-service/src/handlers/views/mod.rs b/matrix-authentication-service/src/handlers/views/mod.rs index eda6f9e6a..a67af32ad 100644 --- a/matrix-authentication-service/src/handlers/views/mod.rs +++ b/matrix-authentication-service/src/handlers/views/mod.rs @@ -12,7 +12,29 @@ // See the License for the specific language governing permissions and // limitations under the License. -pub(super) mod index; -pub(super) mod login; -// pub(super) mod logout; -// pub(super) mod reauth; +use sqlx::PgPool; +use warp::{filters::BoxedFilter, Filter, Reply}; + +use crate::{config::CsrfConfig, templates::Templates}; + +mod index; +mod login; +mod logout; +mod reauth; + +use self::index::filter as index; +use self::login::filter as login; +use self::logout::filter as logout; +use self::reauth::filter as reauth; + +pub(super) fn filter( + pool: PgPool, + templates: Templates, + csrf_config: &CsrfConfig, +) -> BoxedFilter<(impl Reply,)> { + index(pool.clone(), templates.clone(), csrf_config) + .or(login(pool.clone(), templates.clone(), csrf_config)) + .or(logout(csrf_config)) + .or(reauth(pool, templates, csrf_config)) + .boxed() +} diff --git a/matrix-authentication-service/src/handlers/views/reauth.rs b/matrix-authentication-service/src/handlers/views/reauth.rs index 32b0ba1b6..db69971d5 100644 --- a/matrix-authentication-service/src/handlers/views/reauth.rs +++ b/matrix-authentication-service/src/handlers/views/reauth.rs @@ -13,26 +13,91 @@ // limitations under the License. use serde::Deserialize; -use tide::{Redirect, Request, Response}; +use sqlx::PgPool; +use tracing::info; +use warp::{filters::BoxedFilter, reply::with_header, wrap_fn, Filter, Rejection, Reply}; -use crate::{csrf::CsrfForm, state::State, templates::common_context}; +use crate::{ + config::CsrfConfig, + csrf::CsrfForm, + errors::WrapError, + filters::{csrf::with_csrf, with_pool, with_templates, CsrfToken}, + templates::{CommonContext, Templates}, +}; -#[derive(Deserialize)] +#[derive(Deserialize, Debug)] struct ReauthForm { password: String, } -pub async fn get(req: Request) -> tide::Result { - let state = req.state(); - let ctx = common_context(&req).await?; +pub(super) fn filter( + pool: PgPool, + templates: Templates, + csrf_config: &CsrfConfig, +) -> BoxedFilter<(impl Reply,)> { + // TODO: this is ugly and leaks + let csrf_cookie_name = Box::leak(Box::new(csrf_config.cookie_name.clone())); + + let get = warp::get() + .and(with_templates(templates)) + .and(csrf_config.to_extract_filter()) + .and(with_pool(pool.clone())) + .and_then(get) + .untuple_one() + .with(wrap_fn(with_csrf(csrf_config.key, csrf_cookie_name))); + + let post = warp::post() + .and(csrf_config.to_extract_filter()) + .and(with_pool(pool)) + .and(warp::body::form()) + .and_then(post) + .untuple_one() + .with(wrap_fn(with_csrf(csrf_config.key, csrf_cookie_name))); + + warp::path("reauth").and(get.or(post)).boxed() +} + +async fn get( + templates: Templates, + csrf_token: CsrfToken, + db: PgPool, +) -> Result<(CsrfToken, impl Reply), Rejection> { + let ctx = CommonContext::default() + .with_csrf_token(&csrf_token) + .load_session(&db) + .await + .wrap_error()? + .finish() + .wrap_error()?; // TODO: check if there is an existing session - let content = state.templates().render("reauth.html", &ctx)?; - let body = Response::builder(200) - .body(content) - .content_type("text/html") - .into(); - Ok(body) + let content = templates.render("reauth.html", &ctx).wrap_error()?; + Ok(( + csrf_token, + with_header(content, "Content-Type", "text/html"), + )) +} + +async fn post( + csrf_token: CsrfToken, + _db: PgPool, + form: CsrfForm, +) -> Result<(CsrfToken, impl Reply), Rejection> { + let form = form.verify_csrf(&csrf_token).wrap_error()?; + + info!(?form, "reauth"); + + Ok((csrf_token, "unimplemented")) +} + +/* + let form = form.verify_csrf(&csrf_token).wrap_error()?; + + let _session_info = login(&db, &form.username, &form.password) + .await + .wrap_error()?; + + Ok((csrf_token, warp::redirect(Uri::from_static("/")))) } pub async fn post(mut req: Request) -> tide::Result { @@ -52,3 +117,4 @@ pub async fn post(mut req: Request) -> tide::Result { Ok(Redirect::new("/").into()) } +*/ diff --git a/matrix-authentication-service/src/storage/mod.rs b/matrix-authentication-service/src/storage/mod.rs index 329db0398..0fbeb1c00 100644 --- a/matrix-authentication-service/src/storage/mod.rs +++ b/matrix-authentication-service/src/storage/mod.rs @@ -22,7 +22,7 @@ mod user; pub use self::{ client::{Client, ClientLookupError, InvalidRedirectUriError}, - user::{lookup_session, SessionInfo, User}, + user::{login, lookup_session, SessionInfo, User}, }; pub static MIGRATOR: Migrator = sqlx::migrate!(); diff --git a/matrix-authentication-service/src/storage/user.rs b/matrix-authentication-service/src/storage/user.rs index 11fd3737b..7ef3d47fe 100644 --- a/matrix-authentication-service/src/storage/user.rs +++ b/matrix-authentication-service/src/storage/user.rs @@ -47,12 +47,7 @@ impl SessionInfo { impl super::Storage { pub async fn login(&self, username: &str, password: &str) -> anyhow::Result { - let mut txn = self.pool.begin().await?; - let user = lookup_user_by_username(&mut txn, username).await?; - let mut session = start_session(&mut txn, user).await?; - session.last_authd_at = Some(authenticate_session(&mut txn, session.id, password).await?); - txn.commit().await?; - Ok(session) + login(&self.pool, username, password).await } pub async fn register_user(&self, username: &str, password: &str) -> anyhow::Result { @@ -79,6 +74,15 @@ impl super::Storage { } } +pub async fn login(pool: &PgPool, username: &str, password: &str) -> anyhow::Result { + let mut txn = pool.begin().await?; + let user = lookup_user_by_username(&mut txn, username).await?; + let mut session = start_session(&mut txn, user).await?; + session.last_authd_at = Some(authenticate_session(&mut txn, session.id, password).await?); + txn.commit().await?; + Ok(session) +} + pub async fn lookup_session( executor: impl Executor<'_, Database = Postgres>, id: i32, diff --git a/matrix-authentication-service/src/templates.rs b/matrix-authentication-service/src/templates.rs index a7a5bd578..c7bcdeef1 100644 --- a/matrix-authentication-service/src/templates.rs +++ b/matrix-authentication-service/src/templates.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::{ops::Deref, sync::Arc}; + use anyhow::Context as _; use serde::Serialize; use sqlx::{Executor, Postgres}; @@ -23,10 +25,24 @@ use crate::{ storage::{lookup_session, SessionInfo}, }; -pub fn load() -> Result { - let path = format!("{}/templates/**/*.{{html,txt}}", env!("CARGO_MANIFEST_DIR")); - info!(%path, "Loading templates"); - Tera::new(&path) +#[derive(Clone)] +pub struct Templates(Arc); + +impl Templates { + pub fn load() -> Result { + let path = format!("{}/templates/**/*.{{html,txt}}", env!("CARGO_MANIFEST_DIR")); + info!(%path, "Loading templates"); + let tera = Tera::new(&path)?; + Ok(Self(Arc::new(tera))) + } +} + +impl Deref for Templates { + type Target = Tera; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } } #[derive(Serialize, Default)] @@ -43,7 +59,14 @@ impl CommonContext { } } - pub async fn with_session<'e>( + pub fn with_session(self, session: SessionInfo) -> Self { + Self { + session: Some(session), + ..self + } + } + + pub async fn load_session<'e>( self, _executor: impl Executor<'e, Database = Postgres>, ) -> anyhow::Result {