starting the oauth2 authorization flow

also enable compile-time validation of queries
This commit is contained in:
Quentin Gliech
2021-07-31 23:22:17 +02:00
parent 1cfd74dae5
commit dcc84e1083
17 changed files with 662 additions and 46 deletions
Generated
+19
View File
@@ -643,6 +643,9 @@ name = "either"
version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457"
dependencies = [
"serde",
]
[[package]]
name = "fake-simd"
@@ -1120,6 +1123,15 @@ dependencies = [
"nom 6.1.2",
]
[[package]]
name = "itertools"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69ddb889f9d0d08a67338271fa9b62996bc788c7796a5c18cf057420aaed5eaf"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "0.4.7"
@@ -1230,12 +1242,14 @@ dependencies = [
"figment",
"headers",
"hyper",
"itertools",
"mime",
"oauth2-types",
"password-hash",
"rand 0.8.4",
"schemars",
"serde",
"serde_json",
"serde_with",
"serde_yaml",
"sqlx",
@@ -1405,6 +1419,7 @@ dependencies = [
"serde",
"serde_json",
"serde_with",
"sqlx",
"url",
]
@@ -2002,6 +2017,7 @@ version = "1.0.66"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "336b10da19a12ad094b59d870ebde26a45402e5b470add4b5fd03c5048a32127"
dependencies = [
"indexmap",
"itoa",
"ryu",
"serde",
@@ -2239,9 +2255,12 @@ dependencies = [
"either",
"futures",
"heck",
"hex",
"once_cell",
"proc-macro2",
"quote",
"serde",
"serde_json",
"sha2",
"sqlx-core",
"sqlx-rt",
+3 -1
View File
@@ -28,7 +28,7 @@ hyper = { version = "0.14.11", features = ["full"] }
tera = "1.12.1"
# Database access
sqlx = { version = "0.5.5", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono"] }
sqlx = { version = "0.5.5", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline"] }
# Various structure (de)serialization
serde = { version = "1.0.126", features = ["derive"] }
@@ -48,6 +48,7 @@ password-hash = { version = "0.2.2", features = ["std"] }
data-encoding = "2.3.2"
chrono = { version = "0.4.19", features = ["serde"] }
url = { version = "2.2.2", features = ["serde"] }
itertools = "0.10.1"
mime = "0.3.16"
rand = "0.8.4"
bincode = "1.3.3"
@@ -56,3 +57,4 @@ cookie = "0.15.1"
chacha20poly1305 = { version = "0.8.1", features = ["std"] }
oauth2-types = { path = "../oauth2-types" }
serde_json = "1.0.66"
@@ -0,0 +1,17 @@
-- Copyright 2021 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
DROP TRIGGER set_timestamp ON TABLE oauth2_sessions;
DROP TABLE oauth2_codes;
DROP TABLE oauth2_sessions;
@@ -0,0 +1,41 @@
-- Copyright 2021 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
CREATE TABLE oauth2_sessions (
"id" BIGSERIAL PRIMARY KEY,
"user_session_id" BIGINT REFERENCES user_sessions (id) ON DELETE CASCADE,
"client_id" TEXT NOT NULL,
"scope" TEXT NOT NULL,
"state" TEXT,
"nonce" TEXT,
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
"updated_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()
);
CREATE TRIGGER set_timestamp
BEFORE UPDATE ON oauth2_sessions
FOR EACH ROW
EXECUTE PROCEDURE trigger_set_timestamp();
CREATE TABLE oauth2_codes (
"id" BIGSERIAL PRIMARY KEY,
"oauth2_session_id" BIGINT NOT NULL REFERENCES oauth2_sessions (id) ON DELETE CASCADE,
"code" TEXT UNIQUE NOT NULL,
"code_challenge_method" SMALLINT,
"code_challenge" TEXT,
CHECK (("code_challenge" IS NULL AND "code_challenge_method" IS NULL)
OR ("code_challenge" IS NOT NULL AND "code_challenge_method" IS NOT NULL))
);
@@ -0,0 +1,291 @@
{
"db": "PostgreSQL",
"037ba804eabd0b4290d87d1de37054f358eb11397d3a8e4b69a81cdce0a178e0": {
"query": "\n SELECT id, username\n FROM users\n WHERE username = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "username",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false
]
}
},
"138c3297a66107d8428ca10d04f9a4dd75faf9c1d3f84bcedd3b09f55dd84206": {
"query": "\n INSERT INTO oauth2_codes\n (oauth2_session_id, code, code_challenge_method, code_challenge)\n VALUES\n ($1, $2, $3, $4)\n RETURNING\n id, oauth2_session_id, code, code_challenge_method, code_challenge\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "oauth2_session_id",
"type_info": "Int8"
},
{
"ordinal": 2,
"name": "code",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "code_challenge_method",
"type_info": "Int2"
},
{
"ordinal": 4,
"name": "code_challenge",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Int8",
"Text",
"Int2",
"Text"
]
},
"nullable": [
false,
false,
false,
true,
true
]
}
},
"34e61467c9d30fa18f1f96990358f87aeb4ffa5fe0364fe499f6132017e0f20b": {
"query": "\n INSERT INTO oauth2_sessions \n (user_session_id, client_id, scope, state, nonce)\n VALUES\n ($1, $2, $3, $4, $5)\n RETURNING\n id, user_session_id, client_id, scope, state, nonce, created_at, updated_at\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "user_session_id",
"type_info": "Int8"
},
{
"ordinal": 2,
"name": "client_id",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "scope",
"type_info": "Text"
},
{
"ordinal": 4,
"name": "state",
"type_info": "Text"
},
{
"ordinal": 5,
"name": "nonce",
"type_info": "Text"
},
{
"ordinal": 6,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 7,
"name": "updated_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Int8",
"Text",
"Text",
"Text",
"Text"
]
},
"nullable": [
false,
true,
false,
false,
true,
true,
false,
false
]
}
},
"35bedaa6fdf7ac91d54b458b4637f2182c2f82be3e2f80cd2db934ee279a7f2a": {
"query": "\n SELECT id, username\n FROM users\n WHERE id = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "username",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": [
false,
false
]
}
},
"4f925a277d73df779360f81e0cf5d7983b50ebe744f461559dd561b7e36c20d4": {
"query": "\n SELECT\n s.id,\n u.id as user_id,\n u.username,\n s.active,\n s.created_at,\n a.created_at as \"last_authd_at?\"\n FROM user_sessions s\n INNER JOIN users u \n ON s.user_id = u.id\n LEFT JOIN user_session_authentications a\n ON a.session_id = s.id\n WHERE s.id = $1 AND s.active\n ORDER BY a.created_at DESC\n LIMIT 1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "user_id",
"type_info": "Int8"
},
{
"ordinal": 2,
"name": "username",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "active",
"type_info": "Bool"
},
{
"ordinal": 4,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 5,
"name": "last_authd_at?",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": [
false,
false,
false,
false,
false,
false
]
}
},
"9ba45ab114b656105cc46b0c10fb05769860fcdc05eaf54d6225640fb914dab9": {
"query": "\n INSERT INTO user_session_authentications (session_id)\n VALUES ($1)\n RETURNING created_at\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "created_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": [
false
]
}
},
"a09dfe1019110f2ec6eba0d35bafa467ab4b7980dd8b556826f03863f8edb0ab": {
"query": "UPDATE user_sessions SET active = FALSE WHERE id = $1",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": []
}
},
"a552eee8a8e5ffdee4d4789c634851bd64780dfe730807aac20142d7cd643814": {
"query": "\n SELECT u.hashed_password\n FROM user_sessions s\n INNER JOIN users u\n ON u.id = s.user_id \n WHERE s.id = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "hashed_password",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": [
false
]
}
},
"f9a09ff53b6f221649f4f050e3d5ade114f852ddf50a78610a6c0ef0689af681": {
"query": "\n INSERT INTO users (username, hashed_password)\n VALUES ($1, $2)\n RETURNING id\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
}
],
"parameters": {
"Left": [
"Text",
"Text"
]
},
"nullable": [
false
]
}
}
}
@@ -15,12 +15,13 @@
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use thiserror::Error;
use url::Url;
use super::ConfigurationSection;
#[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct OAuth2ClientConfig {
pub client_id: String,
@@ -28,11 +29,37 @@ pub struct OAuth2ClientConfig {
pub redirect_uris: Option<Vec<Url>>,
}
#[derive(Debug, Error)]
#[error("Invalid redirect URI")]
pub struct InvalidRedirectUriError;
impl OAuth2ClientConfig {
pub fn resolve_redirect_uri<'a>(
&'a self,
suggested_uri: &'a Option<Url>,
) -> Result<&'a Url, InvalidRedirectUriError> {
match (suggested_uri, &self.redirect_uris) {
(None, None) => Err(InvalidRedirectUriError),
(None, Some(redirect_uris)) => {
redirect_uris.iter().next().ok_or(InvalidRedirectUriError)
}
(Some(suggested_uri), None) => Ok(suggested_uri),
(Some(suggested_uri), Some(redirect_uris)) => {
if redirect_uris.contains(suggested_uri) {
Ok(suggested_uri)
} else {
Err(InvalidRedirectUriError)
}
}
}
}
}
fn default_oauth2_issuer() -> Url {
"http://[::]:8080".parse().unwrap()
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct OAuth2Config {
#[serde(default = "default_oauth2_issuer")]
pub issuer: Url,
@@ -12,9 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use hyper::header::CONTENT_TYPE;
use mime::TEXT_PLAIN;
use sqlx::PgPool;
use tracing::{info_span, Instrument};
use warp::{filters::BoxedFilter, Filter, Rejection, Reply};
use warp::{filters::BoxedFilter, reply::with_header, Filter, Rejection, Reply};
use crate::{errors::WrapError, filters::with_pool};
@@ -34,5 +36,5 @@ async fn get(pool: PgPool) -> Result<impl Reply, Rejection> {
.await
.wrap_error()?;
Ok(Box::new("ok"))
Ok(with_header("ok", CONTENT_TYPE, TEXT_PLAIN.to_string()))
}
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use hyper::StatusCode;
use sqlx::PgPool;
use warp::{filters::BoxedFilter, Filter};
@@ -30,9 +29,9 @@ pub fn root(
config: &RootConfig,
) -> BoxedFilter<(impl warp::Reply,)> {
health(pool)
.or(oauth2(&config.oauth2))
.or(oauth2(pool, &config.oauth2, &config.cookies))
.or(views(pool, templates, &config.csrf, &config.cookies))
.or(warp::get().map(|| StatusCode::NOT_FOUND))
//.or(warp::get().map(|| StatusCode::NOT_FOUND)) <- This messes up the error reporting
.with(warp::log(module_path!()))
.boxed()
}
@@ -12,13 +12,105 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use oauth2_types::requests::AuthorizationRequest;
use tide::{Body, Request, Response};
use data_encoding::BASE64URL_NOPAD;
use itertools::Itertools;
use oauth2_types::{
pkce,
requests::{AuthorizationRequest, ResponseType},
};
use serde::Deserialize;
use sqlx::PgPool;
use warp::{filters::BoxedFilter, Filter, Rejection, Reply};
use crate::state::State;
use crate::{
config::{CookiesConfig, OAuth2ClientConfig, OAuth2Config},
errors::WrapError,
filters::{session::with_optional_session, with_pool},
storage::{oauth2::start_session, SessionInfo},
};
pub async fn get(req: Request<State>) -> tide::Result {
let params: AuthorizationRequest = req.query()?;
let body = Body::from_json(&params)?;
Ok(Response::builder(200).body(body).build())
#[derive(Deserialize)]
struct Params {
#[serde(flatten)]
auth: AuthorizationRequest,
#[serde(flatten)]
pkce: Option<pkce::Request>,
}
pub fn filter(
pool: &PgPool,
oauth2_config: &OAuth2Config,
cookies_config: &CookiesConfig,
) -> BoxedFilter<(impl Reply,)> {
let clients = oauth2_config.clients.clone();
warp::get()
.and(warp::path!("oauth2" / "authorize"))
.map(move || clients.clone())
.and(warp::query())
.and(with_optional_session(pool, cookies_config))
.and(with_pool(pool))
.and_then(get)
.boxed()
}
async fn get(
clients: Vec<OAuth2ClientConfig>,
params: Params,
maybe_session: Option<SessionInfo>,
pool: PgPool,
) -> Result<impl Reply, Rejection> {
// First, find out what client it is
let client = clients
.into_iter()
.find(|client| client.client_id == params.auth.client_id)
.ok_or_else(|| anyhow::anyhow!("could not find client"))
.wrap_error()?;
// Then, figure out the redirect URI
let redirect_uri = client
.resolve_redirect_uri(&params.auth.redirect_uri)
.wrap_error()?;
// Start a DB transaction
let mut txn = pool.begin().await.wrap_error()?;
let maybe_session_id = maybe_session.as_ref().map(SessionInfo::key);
let scope: String = {
let it = params.auth.scope.iter().map(ToString::to_string);
Itertools::intersperse(it, " ".to_string()).collect()
};
let oauth2_session = start_session(
&mut txn,
maybe_session_id,
&client.client_id,
&scope,
params.auth.state.as_deref(),
params.auth.nonce.as_deref(),
)
.await
.wrap_error()?;
let code = if params.auth.response_type.contains(&ResponseType::Code) {
// 192bit random bytes encoded in base64, which gives a 32 character code
let code: [u8; 24] = rand::random();
let code = BASE64URL_NOPAD.encode(&code);
Some(
oauth2_session
.add_code(&mut txn, &code, &params.pkce)
.await
.wrap_error()?,
)
} else {
None
};
txn.commit().await.wrap_error()?;
Ok(warp::reply::json(&serde_json::json!({
"session": oauth2_session,
"code": code,
"redirect_uri": redirect_uri,
})))
}
@@ -12,15 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use warp::{filters::BoxedFilter, Reply};
use sqlx::PgPool;
use warp::{filters::BoxedFilter, Filter, Reply};
use crate::config::OAuth2Config;
use crate::config::{CookiesConfig, OAuth2Config};
// pub mod authorization;
mod authorization;
mod discovery;
use self::discovery::filter as discovery;
use self::{authorization::filter as authorization, discovery::filter as discovery};
pub fn filter(config: &OAuth2Config) -> BoxedFilter<(impl Reply,)> {
discovery(config)
pub fn filter(
pool: &PgPool,
oauth2_config: &OAuth2Config,
cookies_config: &CookiesConfig,
) -> BoxedFilter<(impl Reply,)> {
discovery(oauth2_config)
.or(authorization(pool, oauth2_config, cookies_config))
.boxed()
}
+2 -1
View File
@@ -16,7 +16,8 @@
#![deny(clippy::all)]
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::unused_async)]
#![allow(clippy::unused_async)] // Some warp filters need that
#![allow(clippy::used_underscore_binding)] // This is needed by sqlx macros
use anyhow::Context;
use clap::Clap;
@@ -14,6 +14,7 @@
use sqlx::migrate::Migrator;
pub mod oauth2;
mod user;
pub use self::user::{login, lookup_active_session, register_user, SessionInfo, User};
@@ -0,0 +1,111 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context;
use chrono::{DateTime, Utc};
use oauth2_types::pkce;
use serde::Serialize;
use sqlx::{Executor, FromRow, Postgres};
#[derive(FromRow, Serialize)]
pub struct OAuth2Session {
id: i64,
user_session_id: Option<i64>,
client_id: String,
scope: String,
state: Option<String>,
nonce: Option<String>,
created_at: DateTime<Utc>,
updated_at: DateTime<Utc>,
}
impl OAuth2Session {
pub async fn add_code<'e>(
&self,
executor: impl Executor<'e, Database = Postgres>,
code: &str,
code_challenge: &Option<pkce::Request>,
) -> anyhow::Result<OAuth2Code> {
add_code(executor, self.id, code, code_challenge).await
}
}
pub async fn start_session(
executor: impl Executor<'_, Database = Postgres>,
optional_session_id: Option<i64>,
client_id: &str,
scope: &str,
state: Option<&str>,
nonce: Option<&str>,
) -> anyhow::Result<OAuth2Session> {
sqlx::query_as!(
OAuth2Session,
r#"
INSERT INTO oauth2_sessions
(user_session_id, client_id, scope, state, nonce)
VALUES
($1, $2, $3, $4, $5)
RETURNING
id, user_session_id, client_id, scope, state, nonce, created_at, updated_at
"#,
optional_session_id,
client_id,
scope,
state,
nonce,
)
.fetch_one(executor)
.await
.context("could not insert oauth2 session")
}
#[derive(FromRow, Serialize)]
pub struct OAuth2Code {
id: i64,
oauth2_session_id: i64,
code: String,
code_challenge: Option<String>,
code_challenge_method: Option<i16>,
}
pub async fn add_code(
executor: impl Executor<'_, Database = Postgres>,
oauth2_session_id: i64,
code: &str,
code_challenge: &Option<pkce::Request>,
) -> anyhow::Result<OAuth2Code> {
let code_challenge_method = code_challenge
.as_ref()
.map(|c| c.code_challenge_method as i16);
let code_challenge = code_challenge.as_ref().map(|c| &c.code_challenge);
sqlx::query_as!(
OAuth2Code,
r#"
INSERT INTO oauth2_codes
(oauth2_session_id, code, code_challenge_method, code_challenge)
VALUES
($1, $2, $3, $4)
RETURNING
id, oauth2_session_id, code, code_challenge_method, code_challenge
"#,
oauth2_session_id,
code,
code_challenge_method,
code_challenge,
)
.fetch_one(executor)
.await
.context("could not insert oauth2 authorization code")
}
@@ -74,7 +74,8 @@ pub async fn lookup_active_session(
executor: impl Executor<'_, Database = Postgres>,
id: i64,
) -> anyhow::Result<SessionInfo> {
sqlx::query_as(
sqlx::query_as!(
SessionInfo,
r#"
SELECT
s.id,
@@ -82,7 +83,7 @@ pub async fn lookup_active_session(
u.username,
s.active,
s.created_at,
a.created_at as last_authd_at
a.created_at as "last_authd_at?"
FROM user_sessions s
INNER JOIN users u
ON s.user_id = u.id
@@ -92,8 +93,8 @@ pub async fn lookup_active_session(
ORDER BY a.created_at DESC
LIMIT 1
"#,
id,
)
.bind(id)
.fetch_one(executor)
.await
.context("could not fetch session")
@@ -131,7 +132,7 @@ pub async fn authenticate_session(
password: &str,
) -> anyhow::Result<DateTime<Utc>> {
// First, fetch the hashed password from the user associated with that session
let hashed_password: String = sqlx::query_scalar(
let hashed_password: String = sqlx::query_scalar!(
r#"
SELECT u.hashed_password
FROM user_sessions s
@@ -139,8 +140,8 @@ pub async fn authenticate_session(
ON u.id = s.user_id
WHERE s.id = $1
"#,
session_id,
)
.bind(session_id)
.fetch_one(txn.borrow_mut())
.await
.context("could not fetch user password hash")?;
@@ -151,14 +152,14 @@ pub async fn authenticate_session(
hasher.verify_password(&[&context], &password)?;
// That went well, let's insert the auth info
let created_at: DateTime<Utc> = sqlx::query_scalar(
let created_at: DateTime<Utc> = sqlx::query_scalar!(
r#"
INSERT INTO user_session_authentications (session_id)
VALUES ($1)
RETURNING created_at
"#,
session_id,
)
.bind(session_id)
.fetch_one(txn.borrow_mut())
.await
.context("could not save session auth")?;
@@ -175,15 +176,15 @@ pub async fn register_user(
let salt = SaltString::generate(&mut OsRng);
let hashed_password = PasswordHash::generate(phf, password, salt.as_str())?;
let id: i64 = sqlx::query_scalar(
let id: i64 = sqlx::query_scalar!(
r#"
INSERT INTO users (username, hashed_password)
VALUES ($1, $2)
RETURNING id
"#,
username,
hashed_password.to_string(),
)
.bind(&username)
.bind(&hashed_password.to_string())
.fetch_one(executor)
.instrument(info_span!("Register user"))
.await
@@ -199,8 +200,7 @@ pub async fn end_session(
executor: impl Executor<'_, Database = Postgres>,
id: i64,
) -> anyhow::Result<()> {
let res = sqlx::query("UPDATE user_sessions SET active = FALSE WHERE id = $1")
.bind(&id)
let res = sqlx::query!("UPDATE user_sessions SET active = FALSE WHERE id = $1", id)
.execute(executor)
.instrument(info_span!("End session"))
.await
@@ -218,14 +218,15 @@ pub async fn lookup_user_by_id(
executor: impl Executor<'_, Database = Postgres>,
id: i64,
) -> anyhow::Result<User> {
sqlx::query_as(
sqlx::query_as!(
User,
r#"
SELECT id, username
FROM users
WHERE id = $1
"#,
id
)
.bind(&id)
.fetch_one(executor)
.instrument(info_span!("Fetch user"))
.await
@@ -236,14 +237,15 @@ pub async fn lookup_user_by_username(
executor: impl Executor<'_, Database = Postgres>,
username: &str,
) -> anyhow::Result<User> {
sqlx::query_as(
sqlx::query_as!(
User,
r#"
SELECT id, username
FROM users
WHERE username = $1
"#,
username,
)
.bind(&username)
.fetch_one(executor)
.instrument(info_span!("Fetch user"))
.await
+1
View File
@@ -14,3 +14,4 @@ url = { version = "2.2.2", features = ["serde"] }
parse-display = "0.5.1"
indoc = "1.0.3"
serde_with = "1.9.4"
sqlx = "0.5.5"
+5 -2
View File
@@ -14,6 +14,7 @@
use parse_display::{Display, FromStr};
use serde::{Deserialize, Serialize};
use sqlx::Type;
#[derive(
Debug,
@@ -28,15 +29,17 @@ use serde::{Deserialize, Serialize};
FromStr,
Serialize,
Deserialize,
Type,
)]
#[repr(i8)]
pub enum CodeChallengeMethod {
#[serde(rename = "plain")]
#[display("plain")]
Plain,
Plain = 0,
#[serde(rename = "S256")]
#[display("S256")]
S256,
S256 = 1,
}
#[derive(Serialize, Deserialize)]
+6 -6
View File
@@ -115,20 +115,20 @@ pub enum Prompt {
#[derive(Serialize, Deserialize)]
pub struct AuthorizationRequest {
#[serde_as(as = "StringWithSeparator::<SpaceSeparator, ResponseType>")]
response_type: HashSet<ResponseType>,
pub response_type: HashSet<ResponseType>,
client_id: String,
pub client_id: String,
redirect_uri: Option<Url>,
pub redirect_uri: Option<Url>,
#[serde_as(as = "StringWithSeparator::<SpaceSeparator, String>")]
scope: HashSet<String>,
pub scope: HashSet<String>,
state: Option<String>,
pub state: Option<String>,
response_mode: Option<ResponseMode>,
nonce: Option<String>,
pub nonce: Option<String>,
display: Option<Display>,