diff --git a/src/api/client/to_device.rs b/src/api/client/to_device.rs index 67dd83f91..17c9a2acf 100644 --- a/src/api/client/to_device.rs +++ b/src/api/client/to_device.rs @@ -1,14 +1,15 @@ use std::collections::BTreeMap; use axum::extract::State; -use conduwuit::{Error, Result}; +use conduwuit::{Err, Result}; use conduwuit_service::sending::EduBuf; use futures::StreamExt; use ruma::{ api::{ - client::{error::ErrorKind, to_device::send_event_to_device}, - federation::{self, transactions::edu::DirectDeviceContent}, + client::to_device::send_event_to_device, + federation::transactions::edu::{DirectDeviceContent, Edu}, }, + assign, to_device::DeviceIdOrAllDevices, }; @@ -31,14 +32,14 @@ pub(crate) async fn send_event_to_device_route( .await .is_ok() { - return Ok(send_event_to_device::v3::Response {}); + return Ok(send_event_to_device::v3::Response::new()); } for (target_user_id, map) in &body.messages { - for (target_device_id_maybe, event) in map { + for (target_device, event) in map { if !services.globals.user_is_local(target_user_id) { let mut map = BTreeMap::new(); - map.insert(target_device_id_maybe.clone(), event.clone()); + map.insert(target_device.clone(), event.clone()); let mut messages = BTreeMap::new(); messages.insert(target_user_id.clone(), map); let count = services.globals.next_count()?; @@ -46,12 +47,14 @@ pub(crate) async fn send_event_to_device_route( let mut buf = EduBuf::new(); serde_json::to_writer( &mut buf, - &federation::transactions::edu::Edu::DirectToDevice(DirectDeviceContent { - sender: sender_user.to_owned(), - ev_type: body.event_type.clone(), - message_id: count.to_string().into(), - messages, - }), + &Edu::DirectToDevice(assign!( + DirectDeviceContent::new( + sender_user.to_owned(), + body.event_type.clone(), + count.to_string().into(), + ), + { messages } + )), ) .expect("DirectToDevice EDU can be serialized"); @@ -64,11 +67,11 @@ pub(crate) async fn send_event_to_device_route( let event_type = &body.event_type.to_string(); - let event = event - .deserialize_as() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?; + let Ok(event) = event.deserialize_as() else { + return Err!(Request(InvalidParam("Failed to deserialize event body."))); + }; - match target_device_id_maybe { + match target_device { | DeviceIdOrAllDevices::DeviceId(target_device_id) => { services .users @@ -81,20 +84,22 @@ pub(crate) async fn send_event_to_device_route( ) .await; }, - | DeviceIdOrAllDevices::AllDevices => { let (event_type, event) = (&event_type, &event); services .users .all_device_ids(target_user_id) - .for_each(|target_device_id| { - services.users.add_to_device_event( - sender_user, - target_user_id, - target_device_id, - event_type, - event.clone(), - ) + .for_each(async |target_device_id| { + services + .users + .add_to_device_event( + sender_user, + target_user_id, + &target_device_id, + event_type, + event.clone(), + ) + .await }) .await; }, @@ -107,5 +112,5 @@ pub(crate) async fn send_event_to_device_route( .transactions .add_client_txnid(sender_user, sender_device, &body.txn_id, &[]); - Ok(send_event_to_device::v3::Response {}) + Ok(send_event_to_device::v3::Response::new()) }