Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 131 additions & 1 deletion src/test/spec/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ mod basic {
}

#[tokio::test(flavor = "multi_thread")]
async fn machine_4_reauthentication() -> anyhow::Result<()> {
async fn machine_4_1_reauthentication() -> anyhow::Result<()> {
let admin_client = Client::with_uri_str(&*MONGODB_URI).await?;

// Now set a failpoint for find with 391 error code
Expand Down Expand Up @@ -393,6 +393,136 @@ mod basic {
Ok(())
}

#[tokio::test(flavor = "multi_thread")]
async fn machine_4_2_read_command_fails_if_reauth_fails() -> anyhow::Result<()> {
let call_count = Arc::new(Mutex::new(0));
let cb_call_count = call_count.clone();

let mut options = ClientOptions::parse(&*MONGODB_URI_SINGLE).await?;
let credential = Credential::builder()
.mechanism(AuthMechanism::MongoDbOidc)
.oidc_callback(oidc::Callback::machine(move |_| {
let call_count = cb_call_count.clone();
async move {
*call_count.lock().await += 1;
let access_token = if *call_count.lock().await == 1 {
get_access_token_test_user_1().await
} else {
"bad token".to_string()
};
Ok(oidc::IdpServerResponse::builder()
.access_token(access_token)
.build())
}
.boxed()
}))
.build();
options.credential = Some(credential);
let client = Client::with_options(options)?;
let collection = client.database("test").collection::<Document>("test");

collection.find_one(doc! {}).await?;

let fail_point =
FailPoint::fail_command(&["find"], FailPointMode::Times(1)).error_code(391);
let _guard = client.enable_fail_point(fail_point).await?;

collection.find_one(doc! {}).await.unwrap_err();

assert_eq!(*call_count.lock().await, 2);

Ok(())
}

#[tokio::test(flavor = "multi_thread")]
async fn machine_4_3_write_command_fails_if_reauth_fails() -> anyhow::Result<()> {
let call_count = Arc::new(Mutex::new(0));
let cb_call_count = call_count.clone();

let mut options = ClientOptions::parse(&*MONGODB_URI_SINGLE).await?;
let credential = Credential::builder()
.mechanism(AuthMechanism::MongoDbOidc)
.oidc_callback(oidc::Callback::machine(move |_| {
let call_count = cb_call_count.clone();
async move {
*call_count.lock().await += 1;
let access_token = if *call_count.lock().await == 1 {
get_access_token_test_user_1().await
} else {
"bad token".to_string()
};
Ok(oidc::IdpServerResponse::builder()
.access_token(access_token)
.build())
}
.boxed()
}))
.build();
options.credential = Some(credential);
let client = Client::with_options(options)?;
let collection = client.database("test").collection::<Document>("test");

collection.insert_one(doc! { "x": 1 }).await?;

let fail_point =
FailPoint::fail_command(&["insert"], FailPointMode::Times(1)).error_code(391);
let _guard = client.enable_fail_point(fail_point).await?;

collection.insert_one(doc! { "y": 2 }).await.unwrap_err();

assert_eq!(*call_count.lock().await, 2);

Ok(())
}

#[tokio::test(flavor = "multi_thread")]
async fn machine_4_4_speculative_auth_ignored_on_reauth() -> anyhow::Result<()> {
let call_count = Arc::new(Mutex::new(0));
let cb_call_count = call_count.clone();

let mut options = ClientOptions::parse(&*MONGODB_URI_SINGLE).await?;
let credential = Credential::builder()
.mechanism(AuthMechanism::MongoDbOidc)
.oidc_callback(oidc::Callback::machine(move |_| {
let call_count = cb_call_count.clone();
async move {
*call_count.lock().await += 1;
Ok(oidc::IdpServerResponse::builder()
.access_token(get_access_token_test_user_1().await)
.build())
}
.boxed()
}))
.build();
credential
.oidc_callback
.set_access_token(Some(get_access_token_test_user_1().await))
.await;
options.credential = Some(credential);
let client = Client::for_test().options(options).monitor_events().await;
let event_buffer = &client.events;
let collection = client.database("test").collection::<Document>("test");

collection.insert_one(doc! { "x": 1 }).await?;

assert_eq!(*call_count.lock().await, 0);
let sasl_start_events = event_buffer.get_command_started_events(&["saslStart"]);
assert!(sasl_start_events.is_empty());

let fail_point =
FailPoint::fail_command(&["insert"], FailPointMode::Times(1)).error_code(391);
let _guard = client.enable_fail_point(fail_point).await?;

collection.insert_one(doc! { "y": 2 }).await?;

assert_eq!(*call_count.lock().await, 1);
let _sasl_start_events = event_buffer.get_command_started_events(&["saslStart"]);
// TODO RUST-2176: unskip this assertion when saslStart events are emitted
// assert!(!sasl_start_events.is_empty());

Ok(())
}

// Human Callback tests
#[tokio::test]
async fn human_1_1_single_principal_implicit_username() -> anyhow::Result<()> {
Expand Down