use std::{
	collections::{BTreeMap, HashMap, HashSet},
	net::IpAddr,
	time::{Duration, Instant},
};

use axum::extract::State;
use axum_client_ip::ClientIp;
use conduwuit::{
	Err, Error, Result, debug, debug_warn, err, error,
	result::LogErr,
	state_res::lexicographical_topological_sort,
	trace,
	utils::{
		IterStream, ReadyExt, millis_since_unix_epoch,
		stream::{BroadbandExt, TryBroadbandExt, automatic_width},
	},
	warn,
};
use conduwuit_service::{
	Services,
	sending::{EDU_LIMIT, PDU_LIMIT},
};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use http::StatusCode;
use itertools::Itertools;
use ruma::{
	CanonicalJsonObject, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId,
	OwnedUserId, RoomId, ServerName, UInt, UserId,
	api::{
		error::{ErrorKind, LimitExceededErrorData},
		federation::transactions::{
			edu::{
				DeviceListUpdateContent, DirectDeviceContent, Edu, PresenceContent,
				PresenceUpdate, ReceiptContent, ReceiptData, ReceiptMap,
				SigningKeyUpdateContent, TypingContent,
			},
			send_transaction_message,
		},
	},
	events::receipt::{ReceiptEvent, ReceiptEventContent, ReceiptType},
	int,
	serde::Raw,
	to_device::DeviceIdOrAllDevices,
};
use service::transactions::{
	FederationTxnState, TransactionError, TxnKey, WrappedTransactionResponse,
};
use tokio::sync::watch::{Receiver, Sender};
use tracing::instrument;

use crate::Ruma;

/// Per-event results reported back to the sending server.
type ResolvedMap = BTreeMap<OwnedEventId, Result<()>>;
type Pdu = (OwnedRoomId, OwnedEventId, CanonicalJsonObject);

/// # `PUT /_matrix/federation/v1/send/{txnId}`
///
/// Push EDUs and PDUs to this server.
pub(crate) async fn send_transaction_message_route(
	State(services): State<crate::State>,
	ClientIp(client): ClientIp,
	body: Ruma<send_transaction_message::v1::Request>,
) -> Result<send_transaction_message::v1::Response> {
	if body.origin() != body.body.origin {
		return Err!(Request(Forbidden(
			"Not allowed to send transactions on behalf of other servers"
		)));
	}

	if body.pdus.len() > PDU_LIMIT {
		return Err!(Request(Forbidden(
			"Not allowed to send more than {PDU_LIMIT} PDUs in one transaction"
		)));
	}

	if body.edus.len() > EDU_LIMIT {
		return Err!(Request(Forbidden(
			"Not allowed to send more than {EDU_LIMIT} EDUs in one transaction"
		)));
	}

	let txn_key = (body.origin().to_owned(), body.transaction_id.clone());

	// Atomically check the cache, join an active transaction, or start a new one
	match services
		.transactions
		.get_or_start_federation_txn(txn_key.clone())?
	{
		| FederationTxnState::Cached(response) => {
			// Already responded
			Ok(response)
		},
		| FederationTxnState::Active(receiver) => {
			// Another task is already processing this transaction
			wait_for_result(receiver).await
		},
		| FederationTxnState::Started { receiver, sender } => {
			// We're the first; spawn the processing task...
			services
				.server
				.runtime()
				.spawn(process_inbound_transaction(services, body, client, txn_key, sender));

			// ...and wait for it
			wait_for_result(receiver).await
		},
	}
}

async fn wait_for_result(
	mut recv: Receiver<WrappedTransactionResponse>,
) -> Result<send_transaction_message::v1::Response> {
	if tokio::time::timeout(Duration::from_secs(50), recv.changed())
		.await
		.is_err()
	{
		// Took too long; return 429 to encourage the sender to try again
		return Err(Error::BadRequest(
			ErrorKind::LimitExceeded(LimitExceededErrorData::new()),
			"Transaction is still being processed. Please try again later.",
		));
	}

	let value = recv.borrow_and_update();
	match value.clone() {
		| Some(Ok(response)) => Ok(response),
		| Some(Err(err)) => Err(transaction_error_to_response(&err)),
		| None => Err(Error::Request(
			ErrorKind::Unknown,
			"Transaction processing failed unexpectedly".into(),
			StatusCode::INTERNAL_SERVER_ERROR,
		)),
	}
}
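// The route above leans on one `tokio::sync::watch` channel per transaction:
// a single task computes the response while every concurrent retry of the
// same `(origin, txn_id)` awaits the same receiver. A minimal sketch of that
// fan-out pattern in isolation (self-contained, hypothetical names; not part
// of this module's API):
//
// ```rust,ignore
// use tokio::sync::watch;
//
// async fn fan_out() {
//     // `None` means "no result yet"; waiters park on `changed()`.
//     let (tx, rx) = watch::channel::<Option<u64>>(None);
//
//     // The first caller spawns the single unit of work...
//     tokio::spawn(async move {
//         tx.send(Some(42)).ok();
//     });
//
//     // ...and any number of waiters observe the same published value.
//     let mut waiter = rx.clone();
//     waiter.changed().await.ok();
//     assert_eq!(*waiter.borrow(), Some(42));
// }
// ```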
#[instrument(
	skip_all,
	fields(
		id = ?body.transaction_id.as_str(),
		origin = ?body.origin()
	)
)]
async fn process_inbound_transaction(
	services: crate::State,
	body: Ruma<send_transaction_message::v1::Request>,
	client: IpAddr,
	txn_key: TxnKey,
	sender: Sender<WrappedTransactionResponse>,
) {
	let txn_start_time = Instant::now();

	let pdus = body
		.pdus
		.iter()
		.stream()
		.broad_then(|pdu| services.rooms.event_handler.parse_incoming_pdu(pdu))
		.inspect_err(|e| warn!("Could not parse incoming PDU: {e}"))
		.ready_filter_map(Result::ok);

	let edus = body
		.edus
		.iter()
		.map(|edu| edu.json().get())
		.map(serde_json::from_str)
		.filter_map(Result::ok)
		.stream();

	debug!(pdus = body.pdus.len(), edus = body.edus.len(), "Processing transaction");

	let results = match handle(&services, &client, body.origin(), pdus, edus).await {
		| Ok(results) => results,
		| Err(err) => {
			fail_federation_txn(services, &txn_key, &sender, err);
			return;
		},
	};

	for (id, result) in &results {
		if let Err(e) = result {
			if matches!(e, Error::BadRequest(ErrorKind::NotFound, _)) {
				debug_warn!("Incoming PDU failed {id}: {e:?}");
			}
		}
	}

	debug!(
		pdus = body.pdus.len(),
		edus = body.edus.len(),
		elapsed = ?txn_start_time.elapsed(),
		"Finished processing transaction"
	);

	let response = send_transaction_message::v1::Response::new(
		results
			.into_iter()
			.map(|(e, r)| (e, r.map_err(error::sanitized_message)))
			.collect(),
	);

	services
		.transactions
		.finish_federation_txn(txn_key, sender, response);
}

/// Handles a failed federation transaction by sending the error through
/// the channel and cleaning up the transaction state. This allows waiters to
/// receive an appropriate error response.
fn fail_federation_txn(
	services: crate::State,
	txn_key: &TxnKey,
	sender: &Sender<WrappedTransactionResponse>,
	err: TransactionError,
) {
	debug!("Transaction failed: {err}");

	// Remove from the active state so the transaction can be retried
	services.transactions.remove_federation_txn(txn_key);

	// Send the error to any waiters
	if let Err(e) = sender.send(Some(Err(err))) {
		debug_warn!("Failed to send transaction error to receivers: {e}");
	}
}

/// Converts a [`TransactionError`] into an appropriate HTTP error response.
fn transaction_error_to_response(err: &TransactionError) -> Error {
	match err {
		| TransactionError::ShuttingDown => Error::Request(
			ErrorKind::Unknown,
			"Server is shutting down, please retry later".into(),
			StatusCode::SERVICE_UNAVAILABLE,
		),
	}
}
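// `handle` below evaluates every PDU and returns per-event results; failures
// are reported inside the 200 response as sanitized strings rather than
// failing the whole transaction. A sketch of a mixed outcome, using the
// `Response::new` call from above (hypothetical event IDs and message):
//
// ```rust,ignore
// let response = send_transaction_message::v1::Response::new(BTreeMap::from([
//     (event_a, Ok(())),                               // accepted
//     (event_b, Err("Event was rejected".to_owned())), // sanitized text only
// ]));
// ```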
async fn handle(
	services: &Services,
	client: &IpAddr,
	origin: &ServerName,
	pdus: impl Stream<Item = Pdu> + Send,
	edus: impl Stream<Item = Edu> + Send,
) -> std::result::Result<ResolvedMap, TransactionError> {
	// group pdus by room
	let pdus = pdus
		.collect()
		.map(|mut pdus: Vec<_>| {
			pdus.sort_by(|(room_a, ..), (room_b, ..)| room_a.cmp(room_b));
			pdus.into_iter()
				.into_grouping_map_by(|(room_id, ..)| room_id.clone())
				.collect()
		})
		.await;

	// we can evaluate rooms concurrently
	let results: ResolvedMap = pdus
		.into_iter()
		.try_stream()
		.broad_and_then(|(room_id, pdus): (_, Vec<_>)| {
			handle_room(services, client, origin, room_id, pdus.into_iter())
				.map_ok(Vec::into_iter)
				.map_ok(IterStream::try_stream)
		})
		.try_flatten()
		.try_collect()
		.boxed()
		.await?;

	// evaluate edus after pdus, at least for now.
	edus.for_each_concurrent(automatic_width(), |edu| handle_edu(services, client, origin, edu))
		.boxed()
		.await;

	Ok(results)
}

/// Attempts to build a localised directed acyclic graph out of the given PDUs,
/// returning them in a topologically sorted order.
///
/// This is used to try to process PDUs in an order that respects their
/// dependencies; however, it is ultimately the sender's responsibility to send
/// them in a processable order, so this is just a best-effort attempt. It does
/// not account for power levels or other tie-breaks.
async fn build_local_dag(
	pdu_map: &HashMap<OwnedEventId, CanonicalJsonObject>,
) -> Result<Vec<OwnedEventId>> {
	debug_assert!(pdu_map.len() >= 2, "needless call to build_local_dag with fewer than 2 PDUs");

	let mut dag: HashMap<OwnedEventId, HashSet<OwnedEventId>> =
		HashMap::with_capacity(pdu_map.len());
	let mut id_origin_ts: HashMap<OwnedEventId, ruma::Int> =
		HashMap::with_capacity(pdu_map.len());
	for (event_id, value) in pdu_map {
		// We already checked that these properties are correct in parse_incoming_pdu,
		// so it's safe to unwrap here.
		//
		// We also filter out any prev_events that are not in this pdu_map, as we need
		// at least one event with zero out-degree for the lexico-topo sort below. If
		// there are multiple events with omitted prevs, they will be ordered by
		// timestamp, then event ID. At that point though, it's unlikely to matter.
		let prev_events = value
			.get("prev_events")
			.unwrap()
			.as_array()
			.unwrap()
			.iter()
			.map(|v| EventId::parse(v.as_str().unwrap()).unwrap())
			.filter(|id| pdu_map.contains_key(id))
			.collect();

		dag.insert(event_id.clone(), prev_events);

		let origin_server_ts = value
			.get("origin_server_ts")
			.and_then(ruma::CanonicalJsonValue::as_integer)
			.unwrap_or_default();

		id_origin_ts.insert(event_id.clone(), origin_server_ts);
	}

	debug!(count = dag.len(), "Sorting incoming events with partial graph");

	lexicographical_topological_sort(&dag, &async |node_id| {
		// Note: we don't bother fetching power levels because that would massively slow
		// this function down. This is a best-effort attempt to order events correctly
		// for processing, however ultimately that should be the sender's job.
		let ts = id_origin_ts
			.get(&node_id)
			.copied()
			.unwrap_or_else(|| int!(0))
			.to_string()
			.parse::<u64>()
			.ok()
			.and_then(UInt::new)
			.unwrap_or_default();

		Ok((int!(0), MilliSecondsSinceUnixEpoch(ts)))
	})
	.await
	.inspect(|sorted| {
		debug_assert_eq!(
			sorted.len(),
			pdu_map.len(),
			"Sorted graph was not the same size as the input graph"
		);
	})
	.map_err(|e| err!("failed to resolve local graph: {e}"))
}
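// What `build_local_dag` hands to the sort, for a hypothetical chain of three
// events a <- b <- c arriving in one transaction (a sketch, not real IDs;
// ties on timestamp fall back to event-ID ordering):
//
// ```rust,ignore
// let mut dag: HashMap<OwnedEventId, HashSet<OwnedEventId>> = HashMap::new();
// dag.insert(c.clone(), HashSet::from([b.clone()]));
// dag.insert(b.clone(), HashSet::from([a.clone()]));
// dag.insert(a.clone(), HashSet::new()); // its prev_events lie outside the txn
// // lexicographical_topological_sort(&dag, ...) yields [a, b, c]: parents
// // first, which is the order handle_room feeds handle_incoming_pdu.
// ```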
async fn handle_room(
	services: &Services,
	_client: &IpAddr,
	origin: &ServerName,
	room_id: OwnedRoomId,
	pdus: impl Iterator<Item = Pdu> + Send,
) -> std::result::Result<Vec<(OwnedEventId, Result<()>)>, TransactionError> {
	let _room_lock = services
		.rooms
		.event_handler
		.mutex_federation
		.lock(room_id.as_str())
		.await;

	let room_id = &room_id;
	let pdu_map: HashMap<OwnedEventId, CanonicalJsonObject> = pdus
		.into_iter()
		.map(|(_, event_id, value)| (event_id, value))
		.collect();

	// Try to sort PDUs by their dependencies, but fall back to arbitrary order on
	// failure (e.g. cycles). This is best-effort; proper ordering is the sender's
	// responsibility.
	let sorted_event_ids = if pdu_map.len() >= 2 {
		build_local_dag(&pdu_map).await.unwrap_or_else(|e| {
			debug_warn!("Failed to build local DAG for room {room_id}: {e}");
			pdu_map.keys().cloned().collect()
		})
	} else {
		pdu_map.keys().cloned().collect()
	};

	let mut results = Vec::with_capacity(sorted_event_ids.len());
	for event_id in sorted_event_ids {
		let value = pdu_map
			.get(&event_id)
			.expect("sorted event IDs must be from the original map")
			.clone();

		services
			.server
			.check_running()
			.map_err(|_| TransactionError::ShuttingDown)?;

		let result = services
			.rooms
			.event_handler
			.handle_incoming_pdu(origin, room_id, &event_id, value, true)
			.await
			.map(|_| ());

		results.push((event_id, result));
	}

	Ok(results)
}

async fn handle_edu(services: &Services, client: &IpAddr, origin: &ServerName, edu: Edu) {
	match edu {
		| Edu::Presence(presence) if services.server.config.allow_incoming_presence =>
			handle_edu_presence(services, client, origin, presence).await,
		| Edu::Receipt(receipt) if services.server.config.allow_incoming_read_receipts =>
			handle_edu_receipt(services, client, origin, receipt).await,
		| Edu::Typing(typing) if services.server.config.allow_incoming_typing =>
			handle_edu_typing(services, client, origin, typing).await,
		| Edu::DeviceListUpdate(content) =>
			handle_edu_device_list_update(services, client, origin, content).await,
		| Edu::DirectToDevice(content) =>
			handle_edu_direct_to_device(services, client, origin, content).await,
		| Edu::SigningKeyUpdate(content) =>
			handle_edu_signing_key_update(services, client, origin, content).await,
		| Edu::_Custom(ref _custom) => debug_warn!(?edu, "received custom/unknown EDU"),
		| _ => trace!(?edu, "skipped"),
	}
}

async fn handle_edu_presence(
	services: &Services,
	_client: &IpAddr,
	origin: &ServerName,
	presence: PresenceContent,
) {
	presence
		.push
		.into_iter()
		.stream()
		.for_each_concurrent(automatic_width(), |update| {
			handle_edu_presence_update(services, origin, update)
		})
		.await;
}

async fn handle_edu_presence_update(
	services: &Services,
	origin: &ServerName,
	update: PresenceUpdate,
) {
	if update.user_id.server_name() != origin {
		debug_warn!(
			%update.user_id, %origin,
			"received presence EDU for user not belonging to origin"
		);
		return;
	}

	services
		.presence
		.set_presence(
			&update.user_id,
			&update.presence,
			Some(update.currently_active),
			Some(update.last_active_ago),
			update.status_msg.clone(),
		)
		.await
		.log_err()
		.ok();
}

async fn handle_edu_receipt(
	services: &Services,
	_client: &IpAddr,
	origin: &ServerName,
	receipt: ReceiptContent,
) {
	receipt
		.receipts
		.into_iter()
		.stream()
		.for_each_concurrent(automatic_width(), |(room_id, room_updates)| {
			handle_edu_receipt_room(services, origin, room_id, room_updates)
		})
		.await;
}

async fn handle_edu_receipt_room(
	services: &Services,
	origin: &ServerName,
	room_id: OwnedRoomId,
	room_updates: ReceiptMap,
) {
	if services
		.rooms
		.event_handler
		.acl_check(origin, &room_id)
		.await
		.is_err()
	{
		debug_warn!(
			%origin, %room_id,
			"received read receipt EDU from ACL'd server"
		);
		return;
	}

	let room_id = &room_id;
	room_updates
		.read
		.into_iter()
		.stream()
		.for_each_concurrent(automatic_width(), |(user_id, user_updates)| async move {
			handle_edu_receipt_room_user(services, origin, room_id, &user_id, user_updates)
				.await;
		})
		.await;
}
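// Receipt EDUs nest three levels deep (room -> user -> event IDs);
// `handle_edu_receipt_room` above peels off the room layer, and the helper
// below validates a single user's `ReceiptData` before expanding it into one
// `ReceiptEvent` per acknowledged event ID.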
async fn handle_edu_receipt_room_user(
	services: &Services,
	origin: &ServerName,
	room_id: &RoomId,
	user_id: &UserId,
	user_updates: ReceiptData,
) {
	if user_id.server_name() != origin {
		debug_warn!(
			%user_id, %origin,
			"received read receipt EDU for user not belonging to origin"
		);
		return;
	}

	if !services
		.rooms
		.state_cache
		.server_in_room(origin, room_id)
		.await
	{
		debug_warn!(
			%user_id, %room_id, %origin,
			"received read receipt EDU from a server with no members in the room"
		);
		return;
	}

	let data = &user_updates.data;
	user_updates
		.event_ids
		.into_iter()
		.stream()
		.for_each_concurrent(automatic_width(), |event_id| async move {
			let user_data = [(user_id.to_owned(), data.clone())];
			let receipts = [(ReceiptType::Read, BTreeMap::from(user_data))];
			let content = [(event_id.clone(), BTreeMap::from(receipts))];
			services
				.rooms
				.read_receipt
				.readreceipt_update(
					user_id,
					room_id,
					&ReceiptEvent::new(
						room_id.to_owned(),
						ReceiptEventContent::from_iter(content),
					),
				)
				.await;
		})
		.await;
}

async fn handle_edu_typing(
	services: &Services,
	_client: &IpAddr,
	origin: &ServerName,
	typing: TypingContent,
) {
	if typing.user_id.server_name() != origin {
		debug_warn!(
			%typing.user_id, %origin,
			"received typing EDU for user not belonging to origin"
		);
		return;
	}

	if services
		.rooms
		.event_handler
		.acl_check(typing.user_id.server_name(), &typing.room_id)
		.await
		.is_err()
	{
		debug_warn!(
			%typing.user_id, %typing.room_id, %origin,
			"received typing EDU from ACL'd server"
		);
		return;
	}

	if !services
		.rooms
		.state_cache
		.is_joined(&typing.user_id, &typing.room_id)
		.await
	{
		debug_warn!(
			%typing.user_id, %typing.room_id, %origin,
			"received typing EDU for user not in room"
		);
		return;
	}

	if typing.typing {
		// Expire the typing notification after the configured federation timeout.
		let secs = services.server.config.typing_federation_timeout_s;
		let timeout = millis_since_unix_epoch().saturating_add(secs.saturating_mul(1000));
		services
			.rooms
			.typing
			.typing_add(&typing.user_id, &typing.room_id, timeout)
			.await
			.log_err()
			.ok();
	} else {
		services
			.rooms
			.typing
			.typing_remove(&typing.user_id, &typing.room_id)
			.await
			.log_err()
			.ok();
	}
}

async fn handle_edu_device_list_update(
	services: &Services,
	_client: &IpAddr,
	origin: &ServerName,
	content: DeviceListUpdateContent,
) {
	let DeviceListUpdateContent { user_id, .. } = content;

	if user_id.server_name() != origin {
		debug_warn!(
			%user_id, %origin,
			"received device list update EDU for user not belonging to origin"
		);
		return;
	}

	services.users.mark_device_key_update(&user_id).await;
}
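// To-device messages are the one EDU kind here with its own replay
// protection: the spec's `message_id` is recorded through the client
// transaction-ID table, so a retried federation transaction does not deliver
// the same events twice (see the `get_client_txn` check below).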
async fn handle_edu_direct_to_device(
	services: &Services,
	_client: &IpAddr,
	origin: &ServerName,
	content: DirectDeviceContent,
) {
	let DirectDeviceContent {
		ref sender,
		ref ev_type,
		ref message_id,
		messages,
		..
	} = content;

	if sender.server_name() != origin {
		debug_warn!(
			%sender, %origin,
			"received direct-to-device EDU for user not belonging to origin"
		);
		return;
	}

	// Check if this is a new transaction id
	if services
		.transactions
		.get_client_txn(sender, None, message_id)
		.await
		.is_ok()
	{
		return;
	}

	// process messages concurrently for different users
	let ev_type = ev_type.to_string();
	messages
		.into_iter()
		.stream()
		.for_each_concurrent(automatic_width(), |(target_user_id, map)| {
			handle_edu_direct_to_device_user(services, target_user_id, sender, &ev_type, map)
		})
		.await;

	// Save the transaction id with empty data
	services
		.transactions
		.add_client_txnid(sender, None, message_id, &[]);
}

async fn handle_edu_direct_to_device_user(
	services: &Services,
	target_user_id: OwnedUserId,
	sender: &UserId,
	ev_type: &str,
	map: BTreeMap<DeviceIdOrAllDevices, Raw<serde_json::Value>>,
) {
	for (target_device_id_maybe, event) in map {
		let Ok(event) = event
			.deserialize_as()
			.map_err(|e| err!(Request(InvalidParam(error!("To-Device event is invalid: {e}")))))
		else {
			continue;
		};

		handle_edu_direct_to_device_event(
			services,
			&target_user_id,
			sender,
			target_device_id_maybe,
			ev_type,
			event,
		)
		.await;
	}
}

async fn handle_edu_direct_to_device_event(
	services: &Services,
	target_user_id: &UserId,
	sender: &UserId,
	target_device_id_maybe: DeviceIdOrAllDevices,
	ev_type: &str,
	event: serde_json::Value,
) {
	match target_device_id_maybe {
		| DeviceIdOrAllDevices::DeviceId(ref target_device_id) => {
			services
				.users
				.add_to_device_event(sender, target_user_id, target_device_id, ev_type, event)
				.await;
		},
		| DeviceIdOrAllDevices::AllDevices => {
			services
				.users
				.all_device_ids(target_user_id)
				.for_each(async |target_device_id| {
					services
						.users
						.add_to_device_event(
							sender,
							target_user_id,
							&target_device_id,
							ev_type,
							event.clone(),
						)
						.await
				})
				.await;
		},
	}
}

async fn handle_edu_signing_key_update(
	services: &Services,
	_client: &IpAddr,
	origin: &ServerName,
	content: SigningKeyUpdateContent,
) {
	let SigningKeyUpdateContent { user_id, master_key, self_signing_key, .. } = content;

	if user_id.server_name() != origin {
		debug_warn!(
			%user_id, %origin,
			"received signing key update EDU for user not belonging to origin"
		);
		return;
	}

	services
		.users
		.add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true)
		.await
		.log_err()
		.ok();
}
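// `DeviceIdOrAllDevices` above is how a single to-device message fans out:
// the federation wire format accepts either a concrete device ID or the
// literal "*" wildcard, which the `AllDevices` variant represents. A quick
// sketch (hypothetical device ID, assuming the usual `Into<OwnedDeviceId>`
// conversions):
//
// ```rust,ignore
// let one = DeviceIdOrAllDevices::DeviceId("TLLBEANAAG".into()); // one device
// let all = DeviceIdOrAllDevices::AllDevices;                    // wire form: "*"
// ```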