diff --git a/Cargo.lock b/Cargo.lock index 8aa07aaa8..44cc3e3e1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -643,6 +643,9 @@ name = "either" version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" +dependencies = [ + "serde", +] [[package]] name = "fake-simd" @@ -1120,6 +1123,15 @@ dependencies = [ "nom 6.1.2", ] +[[package]] +name = "itertools" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69ddb889f9d0d08a67338271fa9b62996bc788c7796a5c18cf057420aaed5eaf" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "0.4.7" @@ -1230,12 +1242,14 @@ dependencies = [ "figment", "headers", "hyper", + "itertools", "mime", "oauth2-types", "password-hash", "rand 0.8.4", "schemars", "serde", + "serde_json", "serde_with", "serde_yaml", "sqlx", @@ -1405,6 +1419,7 @@ dependencies = [ "serde", "serde_json", "serde_with", + "sqlx", "url", ] @@ -2002,6 +2017,7 @@ version = "1.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "336b10da19a12ad094b59d870ebde26a45402e5b470add4b5fd03c5048a32127" dependencies = [ + "indexmap", "itoa", "ryu", "serde", @@ -2239,9 +2255,12 @@ dependencies = [ "either", "futures", "heck", + "hex", "once_cell", "proc-macro2", "quote", + "serde", + "serde_json", "sha2", "sqlx-core", "sqlx-rt", diff --git a/matrix-authentication-service/Cargo.toml b/matrix-authentication-service/Cargo.toml index 281b6983b..c8d805375 100644 --- a/matrix-authentication-service/Cargo.toml +++ b/matrix-authentication-service/Cargo.toml @@ -28,7 +28,7 @@ hyper = { version = "0.14.11", features = ["full"] } tera = "1.12.1" # Database access -sqlx = { version = "0.5.5", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono"] } +sqlx = { version = "0.5.5", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline"] } # Various structure (de)serialization serde = { version = "1.0.126", features = ["derive"] } @@ -48,6 +48,7 @@ password-hash = { version = "0.2.2", features = ["std"] } data-encoding = "2.3.2" chrono = { version = "0.4.19", features = ["serde"] } url = { version = "2.2.2", features = ["serde"] } +itertools = "0.10.1" mime = "0.3.16" rand = "0.8.4" bincode = "1.3.3" @@ -56,3 +57,4 @@ cookie = "0.15.1" chacha20poly1305 = { version = "0.8.1", features = ["std"] } oauth2-types = { path = "../oauth2-types" } +serde_json = "1.0.66" diff --git a/matrix-authentication-service/migrations/20210731130515_oauth2_sessions.down.sql b/matrix-authentication-service/migrations/20210731130515_oauth2_sessions.down.sql new file mode 100644 index 000000000..305ceb408 --- /dev/null +++ b/matrix-authentication-service/migrations/20210731130515_oauth2_sessions.down.sql @@ -0,0 +1,17 @@ +-- Copyright 2021 The Matrix.org Foundation C.I.C. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +DROP TRIGGER set_timestamp ON TABLE oauth2_sessions; +DROP TABLE oauth2_codes; +DROP TABLE oauth2_sessions; diff --git a/matrix-authentication-service/migrations/20210731130515_oauth2_sessions.up.sql b/matrix-authentication-service/migrations/20210731130515_oauth2_sessions.up.sql new file mode 100644 index 000000000..aef6a3cb4 --- /dev/null +++ b/matrix-authentication-service/migrations/20210731130515_oauth2_sessions.up.sql @@ -0,0 +1,41 @@ +-- Copyright 2021 The Matrix.org Foundation C.I.C. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +CREATE TABLE oauth2_sessions ( + "id" BIGSERIAL PRIMARY KEY, + "user_session_id" BIGINT REFERENCES user_sessions (id) ON DELETE CASCADE, + "client_id" TEXT NOT NULL, + "scope" TEXT NOT NULL, + "state" TEXT, + "nonce" TEXT, + + "created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(), + "updated_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now() +); + +CREATE TRIGGER set_timestamp + BEFORE UPDATE ON oauth2_sessions + FOR EACH ROW + EXECUTE PROCEDURE trigger_set_timestamp(); + +CREATE TABLE oauth2_codes ( + "id" BIGSERIAL PRIMARY KEY, + "oauth2_session_id" BIGINT NOT NULL REFERENCES oauth2_sessions (id) ON DELETE CASCADE, + "code" TEXT UNIQUE NOT NULL, + "code_challenge_method" SMALLINT, + "code_challenge" TEXT, + + CHECK (("code_challenge" IS NULL AND "code_challenge_method" IS NULL) + OR ("code_challenge" IS NOT NULL AND "code_challenge_method" IS NOT NULL)) +); diff --git a/matrix-authentication-service/sqlx-data.json b/matrix-authentication-service/sqlx-data.json new file mode 100644 index 000000000..e6114982c --- /dev/null +++ b/matrix-authentication-service/sqlx-data.json @@ -0,0 +1,291 @@ +{ + "db": "PostgreSQL", + "037ba804eabd0b4290d87d1de37054f358eb11397d3a8e4b69a81cdce0a178e0": { + "query": "\n SELECT id, username\n FROM users\n WHERE username = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "username", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + false + ] + } + }, + "138c3297a66107d8428ca10d04f9a4dd75faf9c1d3f84bcedd3b09f55dd84206": { + "query": "\n INSERT INTO oauth2_codes\n (oauth2_session_id, code, code_challenge_method, code_challenge)\n VALUES\n ($1, $2, $3, $4)\n RETURNING\n id, oauth2_session_id, code, code_challenge_method, code_challenge\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "oauth2_session_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "code", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "code_challenge_method", + "type_info": "Int2" + }, + { + "ordinal": 4, + "name": "code_challenge", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Int8", + "Text", + "Int2", + "Text" + ] + }, + "nullable": [ + false, + false, + false, + true, + true + ] + } + }, + "34e61467c9d30fa18f1f96990358f87aeb4ffa5fe0364fe499f6132017e0f20b": { + "query": "\n INSERT INTO oauth2_sessions \n (user_session_id, client_id, scope, state, nonce)\n VALUES\n ($1, $2, $3, $4, $5)\n RETURNING\n id, user_session_id, client_id, scope, state, nonce, created_at, updated_at\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "user_session_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "client_id", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "scope", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "state", + "type_info": "Text" + }, + { + "ordinal": 5, + "name": "nonce", + "type_info": "Text" + }, + { + "ordinal": 6, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 7, + "name": "updated_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Int8", + "Text", + "Text", + "Text", + "Text" + ] + }, + "nullable": [ + false, + true, + false, + false, + true, + true, + false, + false + ] + } + }, + "35bedaa6fdf7ac91d54b458b4637f2182c2f82be3e2f80cd2db934ee279a7f2a": { + "query": "\n SELECT id, username\n FROM users\n WHERE id = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "username", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false + ] + } + }, + "4f925a277d73df779360f81e0cf5d7983b50ebe744f461559dd561b7e36c20d4": { + "query": "\n SELECT\n s.id,\n u.id as user_id,\n u.username,\n s.active,\n s.created_at,\n a.created_at as \"last_authd_at?\"\n FROM user_sessions s\n INNER JOIN users u \n ON s.user_id = u.id\n LEFT JOIN user_session_authentications a\n ON a.session_id = s.id\n WHERE s.id = $1 AND s.active\n ORDER BY a.created_at DESC\n LIMIT 1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "user_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "username", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "active", + "type_info": "Bool" + }, + { + "ordinal": 4, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 5, + "name": "last_authd_at?", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false, + false, + false, + false, + false + ] + } + }, + "9ba45ab114b656105cc46b0c10fb05769860fcdc05eaf54d6225640fb914dab9": { + "query": "\n INSERT INTO user_session_authentications (session_id)\n VALUES ($1)\n RETURNING created_at\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "created_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false + ] + } + }, + "a09dfe1019110f2ec6eba0d35bafa467ab4b7980dd8b556826f03863f8edb0ab": { + "query": "UPDATE user_sessions SET active = FALSE WHERE id = $1", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [] + } + }, + "a552eee8a8e5ffdee4d4789c634851bd64780dfe730807aac20142d7cd643814": { + "query": "\n SELECT u.hashed_password\n FROM user_sessions s\n INNER JOIN users u\n ON u.id = s.user_id \n WHERE s.id = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "hashed_password", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false + ] + } + }, + "f9a09ff53b6f221649f4f050e3d5ade114f852ddf50a78610a6c0ef0689af681": { + "query": "\n INSERT INTO users (username, hashed_password)\n VALUES ($1, $2)\n RETURNING id\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + } + ], + "parameters": { + "Left": [ + "Text", + "Text" + ] + }, + "nullable": [ + false + ] + } + } +} \ No newline at end of file diff --git a/matrix-authentication-service/src/config/oauth2.rs b/matrix-authentication-service/src/config/oauth2.rs index 60fd5443c..fda159559 100644 --- a/matrix-authentication-service/src/config/oauth2.rs +++ b/matrix-authentication-service/src/config/oauth2.rs @@ -15,12 +15,13 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_with::skip_serializing_none; +use thiserror::Error; use url::Url; use super::ConfigurationSection; #[skip_serializing_none] -#[derive(Debug, Serialize, Deserialize, JsonSchema)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct OAuth2ClientConfig { pub client_id: String, @@ -28,11 +29,37 @@ pub struct OAuth2ClientConfig { pub redirect_uris: Option>, } +#[derive(Debug, Error)] +#[error("Invalid redirect URI")] +pub struct InvalidRedirectUriError; + +impl OAuth2ClientConfig { + pub fn resolve_redirect_uri<'a>( + &'a self, + suggested_uri: &'a Option, + ) -> Result<&'a Url, InvalidRedirectUriError> { + match (suggested_uri, &self.redirect_uris) { + (None, None) => Err(InvalidRedirectUriError), + (None, Some(redirect_uris)) => { + redirect_uris.iter().next().ok_or(InvalidRedirectUriError) + } + (Some(suggested_uri), None) => Ok(suggested_uri), + (Some(suggested_uri), Some(redirect_uris)) => { + if redirect_uris.contains(suggested_uri) { + Ok(suggested_uri) + } else { + Err(InvalidRedirectUriError) + } + } + } + } +} + fn default_oauth2_issuer() -> Url { "http://[::]:8080".parse().unwrap() } -#[derive(Debug, Serialize, Deserialize, JsonSchema)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct OAuth2Config { #[serde(default = "default_oauth2_issuer")] pub issuer: Url, diff --git a/matrix-authentication-service/src/handlers/health.rs b/matrix-authentication-service/src/handlers/health.rs index 24fa5ad5c..a1ed5be4b 100644 --- a/matrix-authentication-service/src/handlers/health.rs +++ b/matrix-authentication-service/src/handlers/health.rs @@ -12,9 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +use hyper::header::CONTENT_TYPE; +use mime::TEXT_PLAIN; use sqlx::PgPool; use tracing::{info_span, Instrument}; -use warp::{filters::BoxedFilter, Filter, Rejection, Reply}; +use warp::{filters::BoxedFilter, reply::with_header, Filter, Rejection, Reply}; use crate::{errors::WrapError, filters::with_pool}; @@ -34,5 +36,5 @@ async fn get(pool: PgPool) -> Result { .await .wrap_error()?; - Ok(Box::new("ok")) + Ok(with_header("ok", CONTENT_TYPE, TEXT_PLAIN.to_string())) } diff --git a/matrix-authentication-service/src/handlers/mod.rs b/matrix-authentication-service/src/handlers/mod.rs index c575e136b..462ec313b 100644 --- a/matrix-authentication-service/src/handlers/mod.rs +++ b/matrix-authentication-service/src/handlers/mod.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use hyper::StatusCode; use sqlx::PgPool; use warp::{filters::BoxedFilter, Filter}; @@ -30,9 +29,9 @@ pub fn root( config: &RootConfig, ) -> BoxedFilter<(impl warp::Reply,)> { health(pool) - .or(oauth2(&config.oauth2)) + .or(oauth2(pool, &config.oauth2, &config.cookies)) .or(views(pool, templates, &config.csrf, &config.cookies)) - .or(warp::get().map(|| StatusCode::NOT_FOUND)) + //.or(warp::get().map(|| StatusCode::NOT_FOUND)) <- This messes up the error reporting .with(warp::log(module_path!())) .boxed() } diff --git a/matrix-authentication-service/src/handlers/oauth2/authorization.rs b/matrix-authentication-service/src/handlers/oauth2/authorization.rs index 65fecbd37..79f1f4d23 100644 --- a/matrix-authentication-service/src/handlers/oauth2/authorization.rs +++ b/matrix-authentication-service/src/handlers/oauth2/authorization.rs @@ -12,13 +12,105 @@ // See the License for the specific language governing permissions and // limitations under the License. -use oauth2_types::requests::AuthorizationRequest; -use tide::{Body, Request, Response}; +use data_encoding::BASE64URL_NOPAD; +use itertools::Itertools; +use oauth2_types::{ + pkce, + requests::{AuthorizationRequest, ResponseType}, +}; +use serde::Deserialize; +use sqlx::PgPool; +use warp::{filters::BoxedFilter, Filter, Rejection, Reply}; -use crate::state::State; +use crate::{ + config::{CookiesConfig, OAuth2ClientConfig, OAuth2Config}, + errors::WrapError, + filters::{session::with_optional_session, with_pool}, + storage::{oauth2::start_session, SessionInfo}, +}; -pub async fn get(req: Request) -> tide::Result { - let params: AuthorizationRequest = req.query()?; - let body = Body::from_json(¶ms)?; - Ok(Response::builder(200).body(body).build()) +#[derive(Deserialize)] +struct Params { + #[serde(flatten)] + auth: AuthorizationRequest, + + #[serde(flatten)] + pkce: Option, +} + +pub fn filter( + pool: &PgPool, + oauth2_config: &OAuth2Config, + cookies_config: &CookiesConfig, +) -> BoxedFilter<(impl Reply,)> { + let clients = oauth2_config.clients.clone(); + warp::get() + .and(warp::path!("oauth2" / "authorize")) + .map(move || clients.clone()) + .and(warp::query()) + .and(with_optional_session(pool, cookies_config)) + .and(with_pool(pool)) + .and_then(get) + .boxed() +} + +async fn get( + clients: Vec, + params: Params, + maybe_session: Option, + pool: PgPool, +) -> Result { + // First, find out what client it is + let client = clients + .into_iter() + .find(|client| client.client_id == params.auth.client_id) + .ok_or_else(|| anyhow::anyhow!("could not find client")) + .wrap_error()?; + + // Then, figure out the redirect URI + let redirect_uri = client + .resolve_redirect_uri(¶ms.auth.redirect_uri) + .wrap_error()?; + + // Start a DB transaction + let mut txn = pool.begin().await.wrap_error()?; + let maybe_session_id = maybe_session.as_ref().map(SessionInfo::key); + + let scope: String = { + let it = params.auth.scope.iter().map(ToString::to_string); + Itertools::intersperse(it, " ".to_string()).collect() + }; + + let oauth2_session = start_session( + &mut txn, + maybe_session_id, + &client.client_id, + &scope, + params.auth.state.as_deref(), + params.auth.nonce.as_deref(), + ) + .await + .wrap_error()?; + + let code = if params.auth.response_type.contains(&ResponseType::Code) { + // 192bit random bytes encoded in base64, which gives a 32 character code + let code: [u8; 24] = rand::random(); + let code = BASE64URL_NOPAD.encode(&code); + Some( + oauth2_session + .add_code(&mut txn, &code, ¶ms.pkce) + .await + .wrap_error()?, + ) + } else { + None + }; + + txn.commit().await.wrap_error()?; + + Ok(warp::reply::json(&serde_json::json!({ + "session": oauth2_session, + "code": code, + "redirect_uri": redirect_uri, + }))) } diff --git a/matrix-authentication-service/src/handlers/oauth2/mod.rs b/matrix-authentication-service/src/handlers/oauth2/mod.rs index 15c28cc79..b012eff67 100644 --- a/matrix-authentication-service/src/handlers/oauth2/mod.rs +++ b/matrix-authentication-service/src/handlers/oauth2/mod.rs @@ -12,15 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -use warp::{filters::BoxedFilter, Reply}; +use sqlx::PgPool; +use warp::{filters::BoxedFilter, Filter, Reply}; -use crate::config::OAuth2Config; +use crate::config::{CookiesConfig, OAuth2Config}; -// pub mod authorization; +mod authorization; mod discovery; -use self::discovery::filter as discovery; +use self::{authorization::filter as authorization, discovery::filter as discovery}; -pub fn filter(config: &OAuth2Config) -> BoxedFilter<(impl Reply,)> { - discovery(config) +pub fn filter( + pool: &PgPool, + oauth2_config: &OAuth2Config, + cookies_config: &CookiesConfig, +) -> BoxedFilter<(impl Reply,)> { + discovery(oauth2_config) + .or(authorization(pool, oauth2_config, cookies_config)) + .boxed() } diff --git a/matrix-authentication-service/src/main.rs b/matrix-authentication-service/src/main.rs index c742e9ed6..fafe614c1 100644 --- a/matrix-authentication-service/src/main.rs +++ b/matrix-authentication-service/src/main.rs @@ -16,7 +16,8 @@ #![deny(clippy::all)] #![warn(clippy::pedantic)] #![allow(clippy::module_name_repetitions)] -#![allow(clippy::unused_async)] +#![allow(clippy::unused_async)] // Some warp filters need that +#![allow(clippy::used_underscore_binding)] // This is needed by sqlx macros use anyhow::Context; use clap::Clap; diff --git a/matrix-authentication-service/src/storage/mod.rs b/matrix-authentication-service/src/storage/mod.rs index d42c0cb92..d3b3b8e5c 100644 --- a/matrix-authentication-service/src/storage/mod.rs +++ b/matrix-authentication-service/src/storage/mod.rs @@ -14,6 +14,7 @@ use sqlx::migrate::Migrator; +pub mod oauth2; mod user; pub use self::user::{login, lookup_active_session, register_user, SessionInfo, User}; diff --git a/matrix-authentication-service/src/storage/oauth2.rs b/matrix-authentication-service/src/storage/oauth2.rs new file mode 100644 index 000000000..f6cd00fd8 --- /dev/null +++ b/matrix-authentication-service/src/storage/oauth2.rs @@ -0,0 +1,111 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use anyhow::Context; +use chrono::{DateTime, Utc}; +use oauth2_types::pkce; +use serde::Serialize; +use sqlx::{Executor, FromRow, Postgres}; + +#[derive(FromRow, Serialize)] +pub struct OAuth2Session { + id: i64, + user_session_id: Option, + client_id: String, + scope: String, + state: Option, + nonce: Option, + + created_at: DateTime, + updated_at: DateTime, +} + +impl OAuth2Session { + pub async fn add_code<'e>( + &self, + executor: impl Executor<'e, Database = Postgres>, + code: &str, + code_challenge: &Option, + ) -> anyhow::Result { + add_code(executor, self.id, code, code_challenge).await + } +} + +pub async fn start_session( + executor: impl Executor<'_, Database = Postgres>, + optional_session_id: Option, + client_id: &str, + scope: &str, + state: Option<&str>, + nonce: Option<&str>, +) -> anyhow::Result { + sqlx::query_as!( + OAuth2Session, + r#" + INSERT INTO oauth2_sessions + (user_session_id, client_id, scope, state, nonce) + VALUES + ($1, $2, $3, $4, $5) + RETURNING + id, user_session_id, client_id, scope, state, nonce, created_at, updated_at + "#, + optional_session_id, + client_id, + scope, + state, + nonce, + ) + .fetch_one(executor) + .await + .context("could not insert oauth2 session") +} + +#[derive(FromRow, Serialize)] +pub struct OAuth2Code { + id: i64, + oauth2_session_id: i64, + code: String, + code_challenge: Option, + code_challenge_method: Option, +} + +pub async fn add_code( + executor: impl Executor<'_, Database = Postgres>, + oauth2_session_id: i64, + code: &str, + code_challenge: &Option, +) -> anyhow::Result { + let code_challenge_method = code_challenge + .as_ref() + .map(|c| c.code_challenge_method as i16); + let code_challenge = code_challenge.as_ref().map(|c| &c.code_challenge); + sqlx::query_as!( + OAuth2Code, + r#" + INSERT INTO oauth2_codes + (oauth2_session_id, code, code_challenge_method, code_challenge) + VALUES + ($1, $2, $3, $4) + RETURNING + id, oauth2_session_id, code, code_challenge_method, code_challenge + "#, + oauth2_session_id, + code, + code_challenge_method, + code_challenge, + ) + .fetch_one(executor) + .await + .context("could not insert oauth2 authorization code") +} diff --git a/matrix-authentication-service/src/storage/user.rs b/matrix-authentication-service/src/storage/user.rs index b66773f98..78e849767 100644 --- a/matrix-authentication-service/src/storage/user.rs +++ b/matrix-authentication-service/src/storage/user.rs @@ -74,7 +74,8 @@ pub async fn lookup_active_session( executor: impl Executor<'_, Database = Postgres>, id: i64, ) -> anyhow::Result { - sqlx::query_as( + sqlx::query_as!( + SessionInfo, r#" SELECT s.id, @@ -82,7 +83,7 @@ pub async fn lookup_active_session( u.username, s.active, s.created_at, - a.created_at as last_authd_at + a.created_at as "last_authd_at?" FROM user_sessions s INNER JOIN users u ON s.user_id = u.id @@ -92,8 +93,8 @@ pub async fn lookup_active_session( ORDER BY a.created_at DESC LIMIT 1 "#, + id, ) - .bind(id) .fetch_one(executor) .await .context("could not fetch session") @@ -131,7 +132,7 @@ pub async fn authenticate_session( password: &str, ) -> anyhow::Result> { // First, fetch the hashed password from the user associated with that session - let hashed_password: String = sqlx::query_scalar( + let hashed_password: String = sqlx::query_scalar!( r#" SELECT u.hashed_password FROM user_sessions s @@ -139,8 +140,8 @@ pub async fn authenticate_session( ON u.id = s.user_id WHERE s.id = $1 "#, + session_id, ) - .bind(session_id) .fetch_one(txn.borrow_mut()) .await .context("could not fetch user password hash")?; @@ -151,14 +152,14 @@ pub async fn authenticate_session( hasher.verify_password(&[&context], &password)?; // That went well, let's insert the auth info - let created_at: DateTime = sqlx::query_scalar( + let created_at: DateTime = sqlx::query_scalar!( r#" INSERT INTO user_session_authentications (session_id) VALUES ($1) RETURNING created_at "#, + session_id, ) - .bind(session_id) .fetch_one(txn.borrow_mut()) .await .context("could not save session auth")?; @@ -175,15 +176,15 @@ pub async fn register_user( let salt = SaltString::generate(&mut OsRng); let hashed_password = PasswordHash::generate(phf, password, salt.as_str())?; - let id: i64 = sqlx::query_scalar( + let id: i64 = sqlx::query_scalar!( r#" INSERT INTO users (username, hashed_password) VALUES ($1, $2) RETURNING id "#, + username, + hashed_password.to_string(), ) - .bind(&username) - .bind(&hashed_password.to_string()) .fetch_one(executor) .instrument(info_span!("Register user")) .await @@ -199,8 +200,7 @@ pub async fn end_session( executor: impl Executor<'_, Database = Postgres>, id: i64, ) -> anyhow::Result<()> { - let res = sqlx::query("UPDATE user_sessions SET active = FALSE WHERE id = $1") - .bind(&id) + let res = sqlx::query!("UPDATE user_sessions SET active = FALSE WHERE id = $1", id) .execute(executor) .instrument(info_span!("End session")) .await @@ -218,14 +218,15 @@ pub async fn lookup_user_by_id( executor: impl Executor<'_, Database = Postgres>, id: i64, ) -> anyhow::Result { - sqlx::query_as( + sqlx::query_as!( + User, r#" SELECT id, username FROM users WHERE id = $1 "#, + id ) - .bind(&id) .fetch_one(executor) .instrument(info_span!("Fetch user")) .await @@ -236,14 +237,15 @@ pub async fn lookup_user_by_username( executor: impl Executor<'_, Database = Postgres>, username: &str, ) -> anyhow::Result { - sqlx::query_as( + sqlx::query_as!( + User, r#" SELECT id, username FROM users WHERE username = $1 "#, + username, ) - .bind(&username) .fetch_one(executor) .instrument(info_span!("Fetch user")) .await diff --git a/oauth2-types/Cargo.toml b/oauth2-types/Cargo.toml index ec45e94a5..f387cfaa9 100644 --- a/oauth2-types/Cargo.toml +++ b/oauth2-types/Cargo.toml @@ -14,3 +14,4 @@ url = { version = "2.2.2", features = ["serde"] } parse-display = "0.5.1" indoc = "1.0.3" serde_with = "1.9.4" +sqlx = "0.5.5" diff --git a/oauth2-types/src/pkce.rs b/oauth2-types/src/pkce.rs index d8cc6692e..c036cdb9d 100644 --- a/oauth2-types/src/pkce.rs +++ b/oauth2-types/src/pkce.rs @@ -14,6 +14,7 @@ use parse_display::{Display, FromStr}; use serde::{Deserialize, Serialize}; +use sqlx::Type; #[derive( Debug, @@ -28,15 +29,17 @@ use serde::{Deserialize, Serialize}; FromStr, Serialize, Deserialize, + Type, )] +#[repr(i8)] pub enum CodeChallengeMethod { #[serde(rename = "plain")] #[display("plain")] - Plain, + Plain = 0, #[serde(rename = "S256")] #[display("S256")] - S256, + S256 = 1, } #[derive(Serialize, Deserialize)] diff --git a/oauth2-types/src/requests.rs b/oauth2-types/src/requests.rs index 4e6cdd8c6..e92628005 100644 --- a/oauth2-types/src/requests.rs +++ b/oauth2-types/src/requests.rs @@ -115,20 +115,20 @@ pub enum Prompt { #[derive(Serialize, Deserialize)] pub struct AuthorizationRequest { #[serde_as(as = "StringWithSeparator::")] - response_type: HashSet, + pub response_type: HashSet, - client_id: String, + pub client_id: String, - redirect_uri: Option, + pub redirect_uri: Option, #[serde_as(as = "StringWithSeparator::")] - scope: HashSet, + pub scope: HashSet, - state: Option, + pub state: Option, response_mode: Option, - nonce: Option, + pub nonce: Option, display: Option,