diff --git a/src/service/firstrun/mod.rs b/src/service/firstrun/mod.rs index dfad73896..67dfcd426 100644 --- a/src/service/firstrun/mod.rs +++ b/src/service/firstrun/mod.rs @@ -231,7 +231,8 @@ pub fn print_first_run_banner(&self) { if self.services.config.suspend_on_register { eprintln!( - "{} Accounts created after yours will be suspended, as set in your configuration.", + "{} Accounts created after yours will be suspended, as set in your \ + configuration.", "Your account will not be suspended when you register.".green() ); } @@ -239,11 +240,15 @@ pub fn print_first_run_banner(&self) { if let Some(smtp) = &self.services.config.smtp { if smtp.require_email_for_registration || smtp.require_email_for_token_registration { eprintln!( - "{} Accounts created after yours may be required to provide an email address, as set in your configuration.", + "{} Accounts created after yours may be required to provide an email \ + address, as set in your configuration.", "You will not be asked for your email address when you register.".yellow(), ); } - eprintln!("If you wish to associate an email address with your account, you may do so after registration in your client's settings (if supported)."); + eprintln!( + "If you wish to associate an email address with your account, you may do so \ + after registration in your client's settings (if supported)." + ); } eprintln!( diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 26135ad31..2c62e511b 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -226,9 +226,23 @@ async fn continue_session( return Ok(Err(session.get().info.clone())); } - let completed = 'completed: { + let completed = { let UiaaSession { info, identity } = session.get_mut(); + let auth_type = auth.auth_type().expect("auth type should be set"); + + let flow_stages: Vec> = info + .flows + .iter() + .map(|flow| { + flow.stages + .iter() + .map(AuthType::as_str) + .map(ToOwned::to_owned) + .collect() + }) + .collect(); + let mut completed_stages: HashSet<_> = info .completed .iter() @@ -236,10 +250,16 @@ async fn continue_session( .map(ToOwned::to_owned) .collect(); - // If the provided stage hasn't already been completed, check it for completion - if !completed_stages - .contains(auth.auth_type().expect("auth type should be set").as_str()) + // Don't allow stages which aren't in any flows + if !flow_stages + .iter() + .any(|stages| stages.contains(auth_type.as_str())) { + return Err!(Request(InvalidParam("No flows include the supplied stage"))); + } + + // If the provided stage hasn't already been completed, check it for completion + if !completed_stages.contains(auth_type.as_str()) { match self.check_stage(auth, identity.clone()).await { | Ok((completed_stage, updated_identity)) => { info.auth_error = None; @@ -253,23 +273,10 @@ async fn continue_session( } } - // Check all flows to see if any of them succeeded - for flow in &info.flows { - let flow_stages = flow - .stages - .iter() - .map(AuthType::as_str) - .map(ToOwned::to_owned) - .collect(); - - if completed_stages.is_superset(&flow_stages) { - // All stages in this flow are completed - break 'completed true; - } - } - - // No flows had all their stages completed - break 'completed false; + // UIAA is completed if all stages in any flow are completed + flow_stages + .iter() + .any(|stages| completed_stages.is_superset(stages)) }; if completed {