mirror of
https://github.com/element-hq/matrix-authentication-service.git
synced 2026-05-25 22:54:19 +00:00
starting the oauth2 authorization flow
also enable compile-time validation of queries
This commit is contained in:
Generated
+19
@@ -643,6 +643,9 @@ name = "either"
|
||||
version = "1.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fake-simd"
|
||||
@@ -1120,6 +1123,15 @@ dependencies = [
|
||||
"nom 6.1.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "69ddb889f9d0d08a67338271fa9b62996bc788c7796a5c18cf057420aaed5eaf"
|
||||
dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "0.4.7"
|
||||
@@ -1230,12 +1242,14 @@ dependencies = [
|
||||
"figment",
|
||||
"headers",
|
||||
"hyper",
|
||||
"itertools",
|
||||
"mime",
|
||||
"oauth2-types",
|
||||
"password-hash",
|
||||
"rand 0.8.4",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_with",
|
||||
"serde_yaml",
|
||||
"sqlx",
|
||||
@@ -1405,6 +1419,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_with",
|
||||
"sqlx",
|
||||
"url",
|
||||
]
|
||||
|
||||
@@ -2002,6 +2017,7 @@ version = "1.0.66"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "336b10da19a12ad094b59d870ebde26a45402e5b470add4b5fd03c5048a32127"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
@@ -2239,9 +2255,12 @@ dependencies = [
|
||||
"either",
|
||||
"futures",
|
||||
"heck",
|
||||
"hex",
|
||||
"once_cell",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"sqlx-core",
|
||||
"sqlx-rt",
|
||||
|
||||
@@ -28,7 +28,7 @@ hyper = { version = "0.14.11", features = ["full"] }
|
||||
tera = "1.12.1"
|
||||
|
||||
# Database access
|
||||
sqlx = { version = "0.5.5", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono"] }
|
||||
sqlx = { version = "0.5.5", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline"] }
|
||||
|
||||
# Various structure (de)serialization
|
||||
serde = { version = "1.0.126", features = ["derive"] }
|
||||
@@ -48,6 +48,7 @@ password-hash = { version = "0.2.2", features = ["std"] }
|
||||
data-encoding = "2.3.2"
|
||||
chrono = { version = "0.4.19", features = ["serde"] }
|
||||
url = { version = "2.2.2", features = ["serde"] }
|
||||
itertools = "0.10.1"
|
||||
mime = "0.3.16"
|
||||
rand = "0.8.4"
|
||||
bincode = "1.3.3"
|
||||
@@ -56,3 +57,4 @@ cookie = "0.15.1"
|
||||
chacha20poly1305 = { version = "0.8.1", features = ["std"] }
|
||||
|
||||
oauth2-types = { path = "../oauth2-types" }
|
||||
serde_json = "1.0.66"
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
-- Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
DROP TRIGGER set_timestamp ON TABLE oauth2_sessions;
|
||||
DROP TABLE oauth2_codes;
|
||||
DROP TABLE oauth2_sessions;
|
||||
@@ -0,0 +1,41 @@
|
||||
-- Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
CREATE TABLE oauth2_sessions (
|
||||
"id" BIGSERIAL PRIMARY KEY,
|
||||
"user_session_id" BIGINT REFERENCES user_sessions (id) ON DELETE CASCADE,
|
||||
"client_id" TEXT NOT NULL,
|
||||
"scope" TEXT NOT NULL,
|
||||
"state" TEXT,
|
||||
"nonce" TEXT,
|
||||
|
||||
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
|
||||
"updated_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE TRIGGER set_timestamp
|
||||
BEFORE UPDATE ON oauth2_sessions
|
||||
FOR EACH ROW
|
||||
EXECUTE PROCEDURE trigger_set_timestamp();
|
||||
|
||||
CREATE TABLE oauth2_codes (
|
||||
"id" BIGSERIAL PRIMARY KEY,
|
||||
"oauth2_session_id" BIGINT NOT NULL REFERENCES oauth2_sessions (id) ON DELETE CASCADE,
|
||||
"code" TEXT UNIQUE NOT NULL,
|
||||
"code_challenge_method" SMALLINT,
|
||||
"code_challenge" TEXT,
|
||||
|
||||
CHECK (("code_challenge" IS NULL AND "code_challenge_method" IS NULL)
|
||||
OR ("code_challenge" IS NOT NULL AND "code_challenge_method" IS NOT NULL))
|
||||
);
|
||||
@@ -0,0 +1,291 @@
|
||||
{
|
||||
"db": "PostgreSQL",
|
||||
"037ba804eabd0b4290d87d1de37054f358eb11397d3a8e4b69a81cdce0a178e0": {
|
||||
"query": "\n SELECT id, username\n FROM users\n WHERE username = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "username",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"138c3297a66107d8428ca10d04f9a4dd75faf9c1d3f84bcedd3b09f55dd84206": {
|
||||
"query": "\n INSERT INTO oauth2_codes\n (oauth2_session_id, code, code_challenge_method, code_challenge)\n VALUES\n ($1, $2, $3, $4)\n RETURNING\n id, oauth2_session_id, code, code_challenge_method, code_challenge\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "oauth2_session_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "code",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "code_challenge_method",
|
||||
"type_info": "Int2"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "code_challenge",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8",
|
||||
"Text",
|
||||
"Int2",
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true
|
||||
]
|
||||
}
|
||||
},
|
||||
"34e61467c9d30fa18f1f96990358f87aeb4ffa5fe0364fe499f6132017e0f20b": {
|
||||
"query": "\n INSERT INTO oauth2_sessions \n (user_session_id, client_id, scope, state, nonce)\n VALUES\n ($1, $2, $3, $4, $5)\n RETURNING\n id, user_session_id, client_id, scope, state, nonce, created_at, updated_at\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "user_session_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "client_id",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "scope",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "state",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "nonce",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 7,
|
||||
"name": "updated_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"35bedaa6fdf7ac91d54b458b4637f2182c2f82be3e2f80cd2db934ee279a7f2a": {
|
||||
"query": "\n SELECT id, username\n FROM users\n WHERE id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "username",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"4f925a277d73df779360f81e0cf5d7983b50ebe744f461559dd561b7e36c20d4": {
|
||||
"query": "\n SELECT\n s.id,\n u.id as user_id,\n u.username,\n s.active,\n s.created_at,\n a.created_at as \"last_authd_at?\"\n FROM user_sessions s\n INNER JOIN users u \n ON s.user_id = u.id\n LEFT JOIN user_session_authentications a\n ON a.session_id = s.id\n WHERE s.id = $1 AND s.active\n ORDER BY a.created_at DESC\n LIMIT 1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "user_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "username",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "active",
|
||||
"type_info": "Bool"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "last_authd_at?",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"9ba45ab114b656105cc46b0c10fb05769860fcdc05eaf54d6225640fb914dab9": {
|
||||
"query": "\n INSERT INTO user_session_authentications (session_id)\n VALUES ($1)\n RETURNING created_at\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"a09dfe1019110f2ec6eba0d35bafa467ab4b7980dd8b556826f03863f8edb0ab": {
|
||||
"query": "UPDATE user_sessions SET active = FALSE WHERE id = $1",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
}
|
||||
},
|
||||
"a552eee8a8e5ffdee4d4789c634851bd64780dfe730807aac20142d7cd643814": {
|
||||
"query": "\n SELECT u.hashed_password\n FROM user_sessions s\n INNER JOIN users u\n ON u.id = s.user_id \n WHERE s.id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "hashed_password",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"f9a09ff53b6f221649f4f050e3d5ade114f852ddf50a78610a6c0ef0689af681": {
|
||||
"query": "\n INSERT INTO users (username, hashed_password)\n VALUES ($1, $2)\n RETURNING id\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text",
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -15,12 +15,13 @@
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::skip_serializing_none;
|
||||
use thiserror::Error;
|
||||
use url::Url;
|
||||
|
||||
use super::ConfigurationSection;
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct OAuth2ClientConfig {
|
||||
pub client_id: String,
|
||||
|
||||
@@ -28,11 +29,37 @@ pub struct OAuth2ClientConfig {
|
||||
pub redirect_uris: Option<Vec<Url>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("Invalid redirect URI")]
|
||||
pub struct InvalidRedirectUriError;
|
||||
|
||||
impl OAuth2ClientConfig {
|
||||
pub fn resolve_redirect_uri<'a>(
|
||||
&'a self,
|
||||
suggested_uri: &'a Option<Url>,
|
||||
) -> Result<&'a Url, InvalidRedirectUriError> {
|
||||
match (suggested_uri, &self.redirect_uris) {
|
||||
(None, None) => Err(InvalidRedirectUriError),
|
||||
(None, Some(redirect_uris)) => {
|
||||
redirect_uris.iter().next().ok_or(InvalidRedirectUriError)
|
||||
}
|
||||
(Some(suggested_uri), None) => Ok(suggested_uri),
|
||||
(Some(suggested_uri), Some(redirect_uris)) => {
|
||||
if redirect_uris.contains(suggested_uri) {
|
||||
Ok(suggested_uri)
|
||||
} else {
|
||||
Err(InvalidRedirectUriError)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_oauth2_issuer() -> Url {
|
||||
"http://[::]:8080".parse().unwrap()
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct OAuth2Config {
|
||||
#[serde(default = "default_oauth2_issuer")]
|
||||
pub issuer: Url,
|
||||
|
||||
@@ -12,9 +12,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use hyper::header::CONTENT_TYPE;
|
||||
use mime::TEXT_PLAIN;
|
||||
use sqlx::PgPool;
|
||||
use tracing::{info_span, Instrument};
|
||||
use warp::{filters::BoxedFilter, Filter, Rejection, Reply};
|
||||
use warp::{filters::BoxedFilter, reply::with_header, Filter, Rejection, Reply};
|
||||
|
||||
use crate::{errors::WrapError, filters::with_pool};
|
||||
|
||||
@@ -34,5 +36,5 @@ async fn get(pool: PgPool) -> Result<impl Reply, Rejection> {
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
Ok(Box::new("ok"))
|
||||
Ok(with_header("ok", CONTENT_TYPE, TEXT_PLAIN.to_string()))
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use hyper::StatusCode;
|
||||
use sqlx::PgPool;
|
||||
use warp::{filters::BoxedFilter, Filter};
|
||||
|
||||
@@ -30,9 +29,9 @@ pub fn root(
|
||||
config: &RootConfig,
|
||||
) -> BoxedFilter<(impl warp::Reply,)> {
|
||||
health(pool)
|
||||
.or(oauth2(&config.oauth2))
|
||||
.or(oauth2(pool, &config.oauth2, &config.cookies))
|
||||
.or(views(pool, templates, &config.csrf, &config.cookies))
|
||||
.or(warp::get().map(|| StatusCode::NOT_FOUND))
|
||||
//.or(warp::get().map(|| StatusCode::NOT_FOUND)) <- This messes up the error reporting
|
||||
.with(warp::log(module_path!()))
|
||||
.boxed()
|
||||
}
|
||||
|
||||
@@ -12,13 +12,105 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use oauth2_types::requests::AuthorizationRequest;
|
||||
use tide::{Body, Request, Response};
|
||||
use data_encoding::BASE64URL_NOPAD;
|
||||
use itertools::Itertools;
|
||||
use oauth2_types::{
|
||||
pkce,
|
||||
requests::{AuthorizationRequest, ResponseType},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use sqlx::PgPool;
|
||||
use warp::{filters::BoxedFilter, Filter, Rejection, Reply};
|
||||
|
||||
use crate::state::State;
|
||||
use crate::{
|
||||
config::{CookiesConfig, OAuth2ClientConfig, OAuth2Config},
|
||||
errors::WrapError,
|
||||
filters::{session::with_optional_session, with_pool},
|
||||
storage::{oauth2::start_session, SessionInfo},
|
||||
};
|
||||
|
||||
pub async fn get(req: Request<State>) -> tide::Result {
|
||||
let params: AuthorizationRequest = req.query()?;
|
||||
let body = Body::from_json(¶ms)?;
|
||||
Ok(Response::builder(200).body(body).build())
|
||||
#[derive(Deserialize)]
|
||||
struct Params {
|
||||
#[serde(flatten)]
|
||||
auth: AuthorizationRequest,
|
||||
|
||||
#[serde(flatten)]
|
||||
pkce: Option<pkce::Request>,
|
||||
}
|
||||
|
||||
pub fn filter(
|
||||
pool: &PgPool,
|
||||
oauth2_config: &OAuth2Config,
|
||||
cookies_config: &CookiesConfig,
|
||||
) -> BoxedFilter<(impl Reply,)> {
|
||||
let clients = oauth2_config.clients.clone();
|
||||
warp::get()
|
||||
.and(warp::path!("oauth2" / "authorize"))
|
||||
.map(move || clients.clone())
|
||||
.and(warp::query())
|
||||
.and(with_optional_session(pool, cookies_config))
|
||||
.and(with_pool(pool))
|
||||
.and_then(get)
|
||||
.boxed()
|
||||
}
|
||||
|
||||
async fn get(
|
||||
clients: Vec<OAuth2ClientConfig>,
|
||||
params: Params,
|
||||
maybe_session: Option<SessionInfo>,
|
||||
pool: PgPool,
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
// First, find out what client it is
|
||||
let client = clients
|
||||
.into_iter()
|
||||
.find(|client| client.client_id == params.auth.client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("could not find client"))
|
||||
.wrap_error()?;
|
||||
|
||||
// Then, figure out the redirect URI
|
||||
let redirect_uri = client
|
||||
.resolve_redirect_uri(¶ms.auth.redirect_uri)
|
||||
.wrap_error()?;
|
||||
|
||||
// Start a DB transaction
|
||||
let mut txn = pool.begin().await.wrap_error()?;
|
||||
let maybe_session_id = maybe_session.as_ref().map(SessionInfo::key);
|
||||
|
||||
let scope: String = {
|
||||
let it = params.auth.scope.iter().map(ToString::to_string);
|
||||
Itertools::intersperse(it, " ".to_string()).collect()
|
||||
};
|
||||
|
||||
let oauth2_session = start_session(
|
||||
&mut txn,
|
||||
maybe_session_id,
|
||||
&client.client_id,
|
||||
&scope,
|
||||
params.auth.state.as_deref(),
|
||||
params.auth.nonce.as_deref(),
|
||||
)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
let code = if params.auth.response_type.contains(&ResponseType::Code) {
|
||||
// 192bit random bytes encoded in base64, which gives a 32 character code
|
||||
let code: [u8; 24] = rand::random();
|
||||
let code = BASE64URL_NOPAD.encode(&code);
|
||||
Some(
|
||||
oauth2_session
|
||||
.add_code(&mut txn, &code, ¶ms.pkce)
|
||||
.await
|
||||
.wrap_error()?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
txn.commit().await.wrap_error()?;
|
||||
|
||||
Ok(warp::reply::json(&serde_json::json!({
|
||||
"session": oauth2_session,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
})))
|
||||
}
|
||||
|
||||
@@ -12,15 +12,22 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use warp::{filters::BoxedFilter, Reply};
|
||||
use sqlx::PgPool;
|
||||
use warp::{filters::BoxedFilter, Filter, Reply};
|
||||
|
||||
use crate::config::OAuth2Config;
|
||||
use crate::config::{CookiesConfig, OAuth2Config};
|
||||
|
||||
// pub mod authorization;
|
||||
mod authorization;
|
||||
mod discovery;
|
||||
|
||||
use self::discovery::filter as discovery;
|
||||
use self::{authorization::filter as authorization, discovery::filter as discovery};
|
||||
|
||||
pub fn filter(config: &OAuth2Config) -> BoxedFilter<(impl Reply,)> {
|
||||
discovery(config)
|
||||
pub fn filter(
|
||||
pool: &PgPool,
|
||||
oauth2_config: &OAuth2Config,
|
||||
cookies_config: &CookiesConfig,
|
||||
) -> BoxedFilter<(impl Reply,)> {
|
||||
discovery(oauth2_config)
|
||||
.or(authorization(pool, oauth2_config, cookies_config))
|
||||
.boxed()
|
||||
}
|
||||
|
||||
@@ -16,7 +16,8 @@
|
||||
#![deny(clippy::all)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
#![allow(clippy::unused_async)]
|
||||
#![allow(clippy::unused_async)] // Some warp filters need that
|
||||
#![allow(clippy::used_underscore_binding)] // This is needed by sqlx macros
|
||||
|
||||
use anyhow::Context;
|
||||
use clap::Clap;
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
use sqlx::migrate::Migrator;
|
||||
|
||||
pub mod oauth2;
|
||||
mod user;
|
||||
|
||||
pub use self::user::{login, lookup_active_session, register_user, SessionInfo, User};
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::Context;
|
||||
use chrono::{DateTime, Utc};
|
||||
use oauth2_types::pkce;
|
||||
use serde::Serialize;
|
||||
use sqlx::{Executor, FromRow, Postgres};
|
||||
|
||||
#[derive(FromRow, Serialize)]
|
||||
pub struct OAuth2Session {
|
||||
id: i64,
|
||||
user_session_id: Option<i64>,
|
||||
client_id: String,
|
||||
scope: String,
|
||||
state: Option<String>,
|
||||
nonce: Option<String>,
|
||||
|
||||
created_at: DateTime<Utc>,
|
||||
updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl OAuth2Session {
|
||||
pub async fn add_code<'e>(
|
||||
&self,
|
||||
executor: impl Executor<'e, Database = Postgres>,
|
||||
code: &str,
|
||||
code_challenge: &Option<pkce::Request>,
|
||||
) -> anyhow::Result<OAuth2Code> {
|
||||
add_code(executor, self.id, code, code_challenge).await
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_session(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
optional_session_id: Option<i64>,
|
||||
client_id: &str,
|
||||
scope: &str,
|
||||
state: Option<&str>,
|
||||
nonce: Option<&str>,
|
||||
) -> anyhow::Result<OAuth2Session> {
|
||||
sqlx::query_as!(
|
||||
OAuth2Session,
|
||||
r#"
|
||||
INSERT INTO oauth2_sessions
|
||||
(user_session_id, client_id, scope, state, nonce)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5)
|
||||
RETURNING
|
||||
id, user_session_id, client_id, scope, state, nonce, created_at, updated_at
|
||||
"#,
|
||||
optional_session_id,
|
||||
client_id,
|
||||
scope,
|
||||
state,
|
||||
nonce,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("could not insert oauth2 session")
|
||||
}
|
||||
|
||||
#[derive(FromRow, Serialize)]
|
||||
pub struct OAuth2Code {
|
||||
id: i64,
|
||||
oauth2_session_id: i64,
|
||||
code: String,
|
||||
code_challenge: Option<String>,
|
||||
code_challenge_method: Option<i16>,
|
||||
}
|
||||
|
||||
pub async fn add_code(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
oauth2_session_id: i64,
|
||||
code: &str,
|
||||
code_challenge: &Option<pkce::Request>,
|
||||
) -> anyhow::Result<OAuth2Code> {
|
||||
let code_challenge_method = code_challenge
|
||||
.as_ref()
|
||||
.map(|c| c.code_challenge_method as i16);
|
||||
let code_challenge = code_challenge.as_ref().map(|c| &c.code_challenge);
|
||||
sqlx::query_as!(
|
||||
OAuth2Code,
|
||||
r#"
|
||||
INSERT INTO oauth2_codes
|
||||
(oauth2_session_id, code, code_challenge_method, code_challenge)
|
||||
VALUES
|
||||
($1, $2, $3, $4)
|
||||
RETURNING
|
||||
id, oauth2_session_id, code, code_challenge_method, code_challenge
|
||||
"#,
|
||||
oauth2_session_id,
|
||||
code,
|
||||
code_challenge_method,
|
||||
code_challenge,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("could not insert oauth2 authorization code")
|
||||
}
|
||||
@@ -74,7 +74,8 @@ pub async fn lookup_active_session(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
id: i64,
|
||||
) -> anyhow::Result<SessionInfo> {
|
||||
sqlx::query_as(
|
||||
sqlx::query_as!(
|
||||
SessionInfo,
|
||||
r#"
|
||||
SELECT
|
||||
s.id,
|
||||
@@ -82,7 +83,7 @@ pub async fn lookup_active_session(
|
||||
u.username,
|
||||
s.active,
|
||||
s.created_at,
|
||||
a.created_at as last_authd_at
|
||||
a.created_at as "last_authd_at?"
|
||||
FROM user_sessions s
|
||||
INNER JOIN users u
|
||||
ON s.user_id = u.id
|
||||
@@ -92,8 +93,8 @@ pub async fn lookup_active_session(
|
||||
ORDER BY a.created_at DESC
|
||||
LIMIT 1
|
||||
"#,
|
||||
id,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("could not fetch session")
|
||||
@@ -131,7 +132,7 @@ pub async fn authenticate_session(
|
||||
password: &str,
|
||||
) -> anyhow::Result<DateTime<Utc>> {
|
||||
// First, fetch the hashed password from the user associated with that session
|
||||
let hashed_password: String = sqlx::query_scalar(
|
||||
let hashed_password: String = sqlx::query_scalar!(
|
||||
r#"
|
||||
SELECT u.hashed_password
|
||||
FROM user_sessions s
|
||||
@@ -139,8 +140,8 @@ pub async fn authenticate_session(
|
||||
ON u.id = s.user_id
|
||||
WHERE s.id = $1
|
||||
"#,
|
||||
session_id,
|
||||
)
|
||||
.bind(session_id)
|
||||
.fetch_one(txn.borrow_mut())
|
||||
.await
|
||||
.context("could not fetch user password hash")?;
|
||||
@@ -151,14 +152,14 @@ pub async fn authenticate_session(
|
||||
hasher.verify_password(&[&context], &password)?;
|
||||
|
||||
// That went well, let's insert the auth info
|
||||
let created_at: DateTime<Utc> = sqlx::query_scalar(
|
||||
let created_at: DateTime<Utc> = sqlx::query_scalar!(
|
||||
r#"
|
||||
INSERT INTO user_session_authentications (session_id)
|
||||
VALUES ($1)
|
||||
RETURNING created_at
|
||||
"#,
|
||||
session_id,
|
||||
)
|
||||
.bind(session_id)
|
||||
.fetch_one(txn.borrow_mut())
|
||||
.await
|
||||
.context("could not save session auth")?;
|
||||
@@ -175,15 +176,15 @@ pub async fn register_user(
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let hashed_password = PasswordHash::generate(phf, password, salt.as_str())?;
|
||||
|
||||
let id: i64 = sqlx::query_scalar(
|
||||
let id: i64 = sqlx::query_scalar!(
|
||||
r#"
|
||||
INSERT INTO users (username, hashed_password)
|
||||
VALUES ($1, $2)
|
||||
RETURNING id
|
||||
"#,
|
||||
username,
|
||||
hashed_password.to_string(),
|
||||
)
|
||||
.bind(&username)
|
||||
.bind(&hashed_password.to_string())
|
||||
.fetch_one(executor)
|
||||
.instrument(info_span!("Register user"))
|
||||
.await
|
||||
@@ -199,8 +200,7 @@ pub async fn end_session(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
id: i64,
|
||||
) -> anyhow::Result<()> {
|
||||
let res = sqlx::query("UPDATE user_sessions SET active = FALSE WHERE id = $1")
|
||||
.bind(&id)
|
||||
let res = sqlx::query!("UPDATE user_sessions SET active = FALSE WHERE id = $1", id)
|
||||
.execute(executor)
|
||||
.instrument(info_span!("End session"))
|
||||
.await
|
||||
@@ -218,14 +218,15 @@ pub async fn lookup_user_by_id(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
id: i64,
|
||||
) -> anyhow::Result<User> {
|
||||
sqlx::query_as(
|
||||
sqlx::query_as!(
|
||||
User,
|
||||
r#"
|
||||
SELECT id, username
|
||||
FROM users
|
||||
WHERE id = $1
|
||||
"#,
|
||||
id
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_one(executor)
|
||||
.instrument(info_span!("Fetch user"))
|
||||
.await
|
||||
@@ -236,14 +237,15 @@ pub async fn lookup_user_by_username(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
username: &str,
|
||||
) -> anyhow::Result<User> {
|
||||
sqlx::query_as(
|
||||
sqlx::query_as!(
|
||||
User,
|
||||
r#"
|
||||
SELECT id, username
|
||||
FROM users
|
||||
WHERE username = $1
|
||||
"#,
|
||||
username,
|
||||
)
|
||||
.bind(&username)
|
||||
.fetch_one(executor)
|
||||
.instrument(info_span!("Fetch user"))
|
||||
.await
|
||||
|
||||
@@ -14,3 +14,4 @@ url = { version = "2.2.2", features = ["serde"] }
|
||||
parse-display = "0.5.1"
|
||||
indoc = "1.0.3"
|
||||
serde_with = "1.9.4"
|
||||
sqlx = "0.5.5"
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
use parse_display::{Display, FromStr};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::Type;
|
||||
|
||||
#[derive(
|
||||
Debug,
|
||||
@@ -28,15 +29,17 @@ use serde::{Deserialize, Serialize};
|
||||
FromStr,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
Type,
|
||||
)]
|
||||
#[repr(i8)]
|
||||
pub enum CodeChallengeMethod {
|
||||
#[serde(rename = "plain")]
|
||||
#[display("plain")]
|
||||
Plain,
|
||||
Plain = 0,
|
||||
|
||||
#[serde(rename = "S256")]
|
||||
#[display("S256")]
|
||||
S256,
|
||||
S256 = 1,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
|
||||
@@ -115,20 +115,20 @@ pub enum Prompt {
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct AuthorizationRequest {
|
||||
#[serde_as(as = "StringWithSeparator::<SpaceSeparator, ResponseType>")]
|
||||
response_type: HashSet<ResponseType>,
|
||||
pub response_type: HashSet<ResponseType>,
|
||||
|
||||
client_id: String,
|
||||
pub client_id: String,
|
||||
|
||||
redirect_uri: Option<Url>,
|
||||
pub redirect_uri: Option<Url>,
|
||||
|
||||
#[serde_as(as = "StringWithSeparator::<SpaceSeparator, String>")]
|
||||
scope: HashSet<String>,
|
||||
pub scope: HashSet<String>,
|
||||
|
||||
state: Option<String>,
|
||||
pub state: Option<String>,
|
||||
|
||||
response_mode: Option<ResponseMode>,
|
||||
|
||||
nonce: Option<String>,
|
||||
pub nonce: Option<String>,
|
||||
|
||||
display: Option<Display>,
|
||||
|
||||
|
||||
Reference in New Issue
Block a user