diff --git a/Cargo.lock b/Cargo.lock index ae84ade89..7a11898fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1255,6 +1255,7 @@ dependencies = [ "csrf", "data-encoding", "figment", + "mime", "oauth2-types", "serde", "tera", @@ -1273,6 +1274,12 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc" +[[package]] +name = "mime" +version = "0.3.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d" + [[package]] name = "num-integer" version = "0.1.44" diff --git a/matrix-authentication-service/Cargo.toml b/matrix-authentication-service/Cargo.toml index ef0838a71..c57a463a2 100644 --- a/matrix-authentication-service/Cargo.toml +++ b/matrix-authentication-service/Cargo.toml @@ -21,3 +21,4 @@ csrf = "0.4.0" data-encoding = "2.3.2" time = "0.2.27" tide-tracing = "0.0.11" +mime = "0.3.16" diff --git a/matrix-authentication-service/src/handlers/mod.rs b/matrix-authentication-service/src/handlers/mod.rs index 0123e1011..cdeb62d66 100644 --- a/matrix-authentication-service/src/handlers/mod.rs +++ b/matrix-authentication-service/src/handlers/mod.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - use async_trait::async_trait; use serde::Deserialize; use thiserror::Error; @@ -102,6 +100,7 @@ pub fn install(app: &mut Server) { let mut views = tide::with_state(state.clone()); views.with(state.session_middleware()); views.with(crate::middlewares::csrf); + views.with(crate::middlewares::errors); views.at("/").get(self::views::index::get); views .at("/login") diff --git a/matrix-authentication-service/src/middlewares/errors.rs b/matrix-authentication-service/src/middlewares/errors.rs new file mode 100644 index 000000000..d1eaeda9b --- /dev/null +++ b/matrix-authentication-service/src/middlewares/errors.rs @@ -0,0 +1,184 @@ +use std::cmp::Reverse; +use std::future::Future; +use std::pin::Pin; + +use mime::{Mime, STAR}; +use serde::Serialize; +use tera::Context; +use tide::{ + http::headers::{ACCEPT, LOCATION}, + Body, Request, StatusCode, +}; +use tracing::debug; + +use crate::state::State; +use crate::templates::common_context; + +/// Get the weight parameter for a mime type from 0 to 1000 +fn get_weight(mime: &Mime) -> usize { + let q = mime + .get_param("q") + .map(|q| q.as_str().parse().unwrap_or(0.0)) + .unwrap_or(1.0_f64) + .min(1.0) + .max(0.0); + + // Weight have a 3 digit precision so we can multiply by 1000 and cast to int + (q * 1000.0) as _ +} + +/// Find what content type should be used for a given request +fn preferred_mime_type<'a>( + request: &Request, + supported_types: &'a [Mime], +) -> Option<&'a Mime> { + let accept = request.header(ACCEPT)?; + // Parse the Accept header as a list of mime types with their associated weight + let accepted_types: Vec<(Mime, usize)> = { + let v: Option> = accept + .into_iter() + .map(|value| value.as_str().split(',')) + .flatten() + .map(|mime| { + mime.trim().parse().ok().map(|mime| { + let q = get_weight(&mime); + (mime, q) + }) + }) + .collect(); + let mut v = v?; + v.sort_by_key(|(_, weight)| Reverse(*weight)); + v + }; + + // For each supported content type, find out if it is accepted with what weight and specificity + let mut types: Vec<_> = supported_types + .iter() + .enumerate() + .filter_map(|(index, supported)| { + accepted_types.iter().find_map(|(accepted, weight)| { + if accepted.type_() == supported.type_() + && accepted.subtype() == supported.subtype() + { + // Accept: text/html + Some((supported, *weight, 2_usize, index)) + } else if accepted.type_() == supported.type_() && accepted.subtype() == STAR { + // Accept: text/* + Some((supported, *weight, 1, index)) + } else if accepted.type_() == STAR && accepted.subtype() == STAR { + // Accept: */* + Some((supported, *weight, 0, index)) + } else { + None + } + }) + }) + .collect(); + + types.sort_by_key(|(_, weight, specificity, index)| { + (Reverse(*weight), Reverse(*specificity), *index) + }); + + types.first().map(|(mime, _, _, _)| *mime) +} + +#[derive(Serialize)] +struct ErrorContext { + #[serde(skip_serializing_if = "Option::is_none")] + code: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + description: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + details: Option, +} + +impl ErrorContext { + fn should_render(&self) -> bool { + self.code.is_some() || self.description.is_some() || self.details.is_some() + } +} + +pub fn middleware<'a>( + request: tide::Request, + next: tide::Next<'a, State>, +) -> Pin + Send + 'a>> { + Box::pin(async { + let content_type = preferred_mime_type( + &request, + &[mime::TEXT_PLAIN, mime::TEXT_HTML, mime::APPLICATION_JSON], + ); + debug!("Content-Type from Accept: {:?}", content_type); + + // TODO: We should not clone here + let templates = request.state().templates().clone(); + + // TODO: This context should probably be comptuted somewhere else + let pctx = common_context(&request).await?.clone(); + + let mut response = next.run(request).await; + + // Find out what message should be displayed from the response status code + let (code, description) = match response.status() { + StatusCode::NotFound => (Some("Not found".to_string()), None), + StatusCode::MethodNotAllowed => (Some("Method not allowed".to_string()), None), + StatusCode::Found + | StatusCode::PermanentRedirect + | StatusCode::TemporaryRedirect + | StatusCode::SeeOther => { + let description = response.header(LOCATION).map(|loc| format!("To {}", loc)); + (Some("Redirecting".to_string()), description) + } + StatusCode::InternalServerError => (Some("Internal server error".to_string()), None), + _ => (None, None), + }; + + // If there is an error associated to the response, format it in a nice way with + // a backtrace if we have one + let details = response.take_error().map(|err| { + format!( + "{}{}", + err, + err.backtrace() + .map(|bt| format!("\nBacktrace:\n{}", bt.to_string())) + .unwrap_or_default() + ) + }); + + let error_context = ErrorContext { + code, + description, + details, + }; + + // This is the case if one of the code, description or details is not None + if error_context.should_render() { + match content_type { + Some(c) if c == &mime::APPLICATION_JSON => { + response.set_body(Body::from_json(&error_context)?); + response.set_content_type("application/json"); + } + Some(c) if c == &mime::TEXT_HTML => { + let mut ctx = Context::from_serialize(&error_context)?; + ctx.extend(pctx); + response.set_body(templates.render("error.html", &ctx)?); + response.set_content_type("text/html"); + } + Some(c) if c == &mime::TEXT_PLAIN => { + let mut ctx = Context::from_serialize(&error_context)?; + ctx.extend(pctx); + response.set_body(templates.render("error.txt", &ctx)?); + response.set_content_type("text/plain"); + } + _ => { + response.set_body("Unsupported Content-Type in Accept header"); + response.set_content_type("text/plain"); + response.set_status(StatusCode::NotAcceptable); + } + } + } + + Ok(response) + }) +} diff --git a/matrix-authentication-service/src/middlewares/mod.rs b/matrix-authentication-service/src/middlewares/mod.rs index 075819617..88bff092a 100644 --- a/matrix-authentication-service/src/middlewares/mod.rs +++ b/matrix-authentication-service/src/middlewares/mod.rs @@ -1,3 +1,5 @@ mod csrf; +mod errors; pub use self::csrf::middleware as csrf; +pub use self::errors::middleware as errors; diff --git a/matrix-authentication-service/src/templates.rs b/matrix-authentication-service/src/templates.rs index d82e58ae7..a2720c70a 100644 --- a/matrix-authentication-service/src/templates.rs +++ b/matrix-authentication-service/src/templates.rs @@ -6,7 +6,7 @@ use tracing::info; use crate::state::State; pub fn load() -> Result { - let path = format!("{}/templates/**/*.html", env!("CARGO_MANIFEST_DIR")); + let path = format!("{}/templates/**/*.{{html,txt}}", env!("CARGO_MANIFEST_DIR")); info!(%path, "Loading templates"); Tera::new(&path) } diff --git a/matrix-authentication-service/templates/base.html b/matrix-authentication-service/templates/base.html index bfd521855..a64116c3b 100644 --- a/matrix-authentication-service/templates/base.html +++ b/matrix-authentication-service/templates/base.html @@ -7,7 +7,7 @@ -