diff --git a/src/api/client/profile.rs b/src/api/client/profile.rs index e8e1c2dc0..9b3ceea3b 100644 --- a/src/api/client/profile.rs +++ b/src/api/client/profile.rs @@ -96,26 +96,7 @@ async fn fetch_full_profile( ) -> Option> { // If the user exists locally, fetch their local profile if services.users.exists(user_id).await { - let mut profile = BTreeMap::new(); - - // Get displayname and avatar_url independently because `all_profile_keys` - // doesn't include them - for field in [ProfileFieldName::AvatarUrl, ProfileFieldName::DisplayName] { - let key = field.as_str().to_owned(); - - if let Some(value) = get_local_profile_field(services, user_id, field).await { - profile.insert(key, value.value().into_owned()); - } - } - - // Insert all other profile fields - let mut all_fields = services.users.all_profile_keys(user_id); - - while let Some((key, value)) = all_fields.next().await { - profile.insert(key, value); - } - - return Some(profile); + return Some(get_local_profile(services, user_id).await); } // Otherwise ask their homeserver @@ -188,7 +169,33 @@ async fn fetch_profile_field( } } -async fn get_local_profile_field( +pub(crate) async fn get_local_profile( + services: &Services, + user_id: &UserId, +) -> BTreeMap { + let mut profile = BTreeMap::new(); + + // Get displayname and avatar_url independently because `all_profile_keys` + // doesn't include them + for field in [ProfileFieldName::AvatarUrl, ProfileFieldName::DisplayName] { + let key = field.as_str().to_owned(); + + if let Some(value) = get_local_profile_field(services, user_id, field).await { + profile.insert(key, value.value().into_owned()); + } + } + + // Insert all other profile fields + let mut all_fields = services.users.all_profile_keys(user_id); + + while let Some((key, value)) = all_fields.next().await { + profile.insert(key, value); + } + + return profile; +} + +pub(crate) async fn get_local_profile_field( services: &Services, user_id: &UserId, field: ProfileFieldName, diff --git a/src/api/server/query.rs b/src/api/server/query.rs index 9124eee9c..ff19d8214 100644 --- a/src/api/server/query.rs +++ b/src/api/server/query.rs @@ -1,19 +1,16 @@ -use std::collections::BTreeMap; - use axum::extract::State; -use conduwuit::{Error, Result, err}; +use conduwuit::{Err, Result, err}; use futures::StreamExt; -use get_profile_information::v1::ProfileField; use rand::seq::SliceRandom; use ruma::{ OwnedServerName, - api::{ - client::error::ErrorKind, - federation::query::{get_profile_information, get_room_information}, - }, + api::federation::query::{get_profile_information, get_room_information}, }; -use crate::Ruma; +use crate::{ + Ruma, + client::{get_local_profile, get_local_profile_field}, +}; /// # `GET /_matrix/federation/v1/query/directory` /// @@ -33,7 +30,6 @@ pub(crate) async fn get_room_information_route( .rooms .state_cache .room_servers(&room_id) - .map(ToOwned::to_owned) .collect() .await; @@ -51,7 +47,7 @@ pub(crate) async fn get_room_information_route( servers.insert(0, services.globals.server_name().to_owned()); } - Ok(get_room_information::v1::Response { room_id, servers }) + Ok(get_room_information::v1::Response::new(room_id, servers)) } /// # `GET /_matrix/federation/v1/query/profile` @@ -67,57 +63,31 @@ pub(crate) async fn get_profile_information_route( .config .allow_inbound_profile_lookup_federation_requests { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Profile lookup over federation is not allowed on this homeserver.", - )); + return Err!(Request(Forbidden( + "Profile lookup over federation is not allowed on this homeserver." + ))); } if !services.globals.server_is_ours(body.user_id.server_name()) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "User does not belong to this server.", - )); + return Err!(Request(InvalidParam("User does not belong to this server."))); } - let mut displayname = None; - let mut avatar_url = None; - let mut blurhash = None; - let mut custom_profile_fields = BTreeMap::new(); + let response = if let Some(field) = &body.field { + let mut response = get_profile_information::v1::Response::new(); - match &body.field { - | Some(ProfileField::DisplayName) => { - displayname = services.users.displayname(&body.user_id).await.ok(); - }, - | Some(ProfileField::AvatarUrl) => { - avatar_url = services.users.avatar_url(&body.user_id).await.ok(); - blurhash = services.users.blurhash(&body.user_id).await.ok(); - }, - | Some(custom_field) => { - if let Ok(value) = services - .users - .profile_key(&body.user_id, custom_field.as_str()) - .await - { - custom_profile_fields.insert(custom_field.to_string(), value); - } - }, - | None => { - displayname = services.users.displayname(&body.user_id).await.ok(); - avatar_url = services.users.avatar_url(&body.user_id).await.ok(); - blurhash = services.users.blurhash(&body.user_id).await.ok(); - custom_profile_fields = services - .users - .all_profile_keys(&body.user_id) - .collect() - .await; - }, - } + if let Some(value) = + get_local_profile_field(&services, &body.user_id, field.to_owned()).await + { + response.set(value.field_name().as_str().to_owned(), value.value().into_owned()); + } - Ok(get_profile_information::v1::Response { - displayname, - avatar_url, - blurhash, - custom_profile_fields, - }) + response + } else { + get_local_profile(&services, &body.user_id) + .await + .into_iter() + .collect() + }; + + Ok(response) }