diff --git a/src/test/spec/oidc.rs b/src/test/spec/oidc.rs index a1988e71c..ac2036d2d 100644 --- a/src/test/spec/oidc.rs +++ b/src/test/spec/oidc.rs @@ -353,7 +353,7 @@ mod basic { } #[tokio::test(flavor = "multi_thread")] - async fn machine_4_reauthentication() -> anyhow::Result<()> { + async fn machine_4_1_reauthentication() -> anyhow::Result<()> { let admin_client = Client::with_uri_str(&*MONGODB_URI).await?; // Now set a failpoint for find with 391 error code @@ -393,6 +393,136 @@ mod basic { Ok(()) } + #[tokio::test(flavor = "multi_thread")] + async fn machine_4_2_read_command_fails_if_reauth_fails() -> anyhow::Result<()> { + let call_count = Arc::new(Mutex::new(0)); + let cb_call_count = call_count.clone(); + + let mut options = ClientOptions::parse(&*MONGODB_URI_SINGLE).await?; + let credential = Credential::builder() + .mechanism(AuthMechanism::MongoDbOidc) + .oidc_callback(oidc::Callback::machine(move |_| { + let call_count = cb_call_count.clone(); + async move { + *call_count.lock().await += 1; + let access_token = if *call_count.lock().await == 1 { + get_access_token_test_user_1().await + } else { + "bad token".to_string() + }; + Ok(oidc::IdpServerResponse::builder() + .access_token(access_token) + .build()) + } + .boxed() + })) + .build(); + options.credential = Some(credential); + let client = Client::with_options(options)?; + let collection = client.database("test").collection::("test"); + + collection.find_one(doc! {}).await?; + + let fail_point = + FailPoint::fail_command(&["find"], FailPointMode::Times(1)).error_code(391); + let _guard = client.enable_fail_point(fail_point).await?; + + collection.find_one(doc! {}).await.unwrap_err(); + + assert_eq!(*call_count.lock().await, 2); + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + async fn machine_4_3_write_command_fails_if_reauth_fails() -> anyhow::Result<()> { + let call_count = Arc::new(Mutex::new(0)); + let cb_call_count = call_count.clone(); + + let mut options = ClientOptions::parse(&*MONGODB_URI_SINGLE).await?; + let credential = Credential::builder() + .mechanism(AuthMechanism::MongoDbOidc) + .oidc_callback(oidc::Callback::machine(move |_| { + let call_count = cb_call_count.clone(); + async move { + *call_count.lock().await += 1; + let access_token = if *call_count.lock().await == 1 { + get_access_token_test_user_1().await + } else { + "bad token".to_string() + }; + Ok(oidc::IdpServerResponse::builder() + .access_token(access_token) + .build()) + } + .boxed() + })) + .build(); + options.credential = Some(credential); + let client = Client::with_options(options)?; + let collection = client.database("test").collection::("test"); + + collection.insert_one(doc! { "x": 1 }).await?; + + let fail_point = + FailPoint::fail_command(&["insert"], FailPointMode::Times(1)).error_code(391); + let _guard = client.enable_fail_point(fail_point).await?; + + collection.insert_one(doc! { "y": 2 }).await.unwrap_err(); + + assert_eq!(*call_count.lock().await, 2); + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + async fn machine_4_4_speculative_auth_ignored_on_reauth() -> anyhow::Result<()> { + let call_count = Arc::new(Mutex::new(0)); + let cb_call_count = call_count.clone(); + + let mut options = ClientOptions::parse(&*MONGODB_URI_SINGLE).await?; + let credential = Credential::builder() + .mechanism(AuthMechanism::MongoDbOidc) + .oidc_callback(oidc::Callback::machine(move |_| { + let call_count = cb_call_count.clone(); + async move { + *call_count.lock().await += 1; + Ok(oidc::IdpServerResponse::builder() + .access_token(get_access_token_test_user_1().await) + .build()) + } + .boxed() + })) + .build(); + credential + .oidc_callback + .set_access_token(Some(get_access_token_test_user_1().await)) + .await; + options.credential = Some(credential); + let client = Client::for_test().options(options).monitor_events().await; + let event_buffer = &client.events; + let collection = client.database("test").collection::("test"); + + collection.insert_one(doc! { "x": 1 }).await?; + + assert_eq!(*call_count.lock().await, 0); + let sasl_start_events = event_buffer.get_command_started_events(&["saslStart"]); + assert!(sasl_start_events.is_empty()); + + let fail_point = + FailPoint::fail_command(&["insert"], FailPointMode::Times(1)).error_code(391); + let _guard = client.enable_fail_point(fail_point).await?; + + collection.insert_one(doc! { "y": 2 }).await?; + + assert_eq!(*call_count.lock().await, 1); + let _sasl_start_events = event_buffer.get_command_started_events(&["saslStart"]); + // TODO RUST-2176: unskip this assertion when saslStart events are emitted + // assert!(!sasl_start_events.is_empty()); + + Ok(()) + } + // Human Callback tests #[tokio::test] async fn human_1_1_single_principal_implicit_username() -> anyhow::Result<()> {