Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ typed-builder = "0.20.0"
webpki-roots = "0.26"
zstd = { version = "0.11.2", optional = true }
macro_magic = "0.5.1"
rustversion = "1.0.20"

[dependencies.pbkdf2]
version = "0.11.0"
Expand Down
238 changes: 138 additions & 100 deletions src/client/session/action.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,67 @@ impl<'a> Action for StartTransaction<&'a mut ClientSession> {
}
}

macro_rules! convenient_run {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic here is unchanged.

(
$session:expr,
$start_transaction:expr,
$callback:expr,
$abort_transaction:expr,
$commit_transaction:expr,
) => {{
let timeout = Duration::from_secs(120);
#[cfg(test)]
let timeout = $session.convenient_transaction_timeout.unwrap_or(timeout);
let start = Instant::now();

use crate::error::{TRANSIENT_TRANSACTION_ERROR, UNKNOWN_TRANSACTION_COMMIT_RESULT};

'transaction: loop {
$start_transaction?;
let ret = match $callback {
Ok(v) => v,
Err(e) => {
if matches!(
$session.transaction.state,
TransactionState::Starting | TransactionState::InProgress
) {
$abort_transaction?;
}
if e.contains_label(TRANSIENT_TRANSACTION_ERROR) && start.elapsed() < timeout {
continue 'transaction;
}
return Err(e);
}
};
if matches!(
$session.transaction.state,
TransactionState::None
| TransactionState::Aborted
| TransactionState::Committed { .. }
) {
return Ok(ret);
}
'commit: loop {
match $commit_transaction {
Ok(()) => return Ok(ret),
Err(e) => {
if e.is_max_time_ms_expired_error() || start.elapsed() >= timeout {
return Err(e);
}
if e.contains_label(UNKNOWN_TRANSACTION_COMMIT_RESULT) {
continue 'commit;
}
if e.contains_label(TRANSIENT_TRANSACTION_ERROR) {
continue 'transaction;
}
return Err(e);
}
}
}
}
}};
}

impl StartTransaction<&mut ClientSession> {
/// Starts a transaction, runs the given callback, and commits or aborts the transaction.
/// Transient transaction errors will cause the callback or the commit to be retried;
Expand Down Expand Up @@ -146,66 +207,84 @@ impl StartTransaction<&mut ClientSession> {
/// # Ok(())
/// # }
/// ```
#[rustversion::attr(since(1.85), deprecated = "use and_run2")]
pub async fn and_run<R, C, F>(self, mut context: C, mut callback: F) -> Result<R>
where
F: for<'b> FnMut(&'b mut ClientSession, &'b mut C) -> BoxFuture<'b, Result<R>>,
{
let timeout = Duration::from_secs(120);
#[cfg(test)]
let timeout = self
.session
.convenient_transaction_timeout
.unwrap_or(timeout);
let start = Instant::now();

use crate::error::{TRANSIENT_TRANSACTION_ERROR, UNKNOWN_TRANSACTION_COMMIT_RESULT};
convenient_run!(
self.session,
self.session
.start_transaction()
.with_options(self.options.clone())
.await,
callback(self.session, &mut context).await,
self.session.abort_transaction().await,
self.session.commit_transaction().await,
)
}

'transaction: loop {
/// Starts a transaction, runs the given callback, and commits or aborts the transaction.
/// Transient transaction errors will cause the callback or the commit to be retried;
/// other errors will cause the transaction to be aborted and the error returned to the
/// caller. If the callback needs to provide its own error information, the
/// [`Error::custom`](crate::error::Error::custom) method can accept an arbitrary payload that
/// can be retrieved via [`Error::get_custom`](crate::error::Error::get_custom).
///
/// If a command inside the callback fails, it may cause the transaction on the server to be
/// aborted. This situation is normally handled transparently by the driver. However, if the
/// application does not return that error from the callback, the driver will not be able to
/// determine whether the transaction was aborted or not. The driver will then retry the
/// callback indefinitely. To avoid this situation, the application MUST NOT silently handle
/// errors within the callback. If the application needs to handle errors within the
/// callback, it MUST return them after doing so.
///
/// This version of the method uses an async closure, which means it's both more convenient and
/// avoids the lifetime issues of `and_run`, but is only available in Rust versions 1.85 and
/// above.
///
/// Because the callback can be repeatedly executed, code within the callback cannot consume
/// owned values, even values owned by the callback itself:
///
/// ```no_run
/// # use mongodb::{bson::{doc, Document}, error::Result, Client};
/// # use futures::FutureExt;
/// # async fn wrapper() -> Result<()> {
/// # let client = Client::with_uri_str("mongodb://example.com").await?;
/// # let mut session = client.start_session().await?;
/// let coll = client.database("mydb").collection::<Document>("mycoll");
/// let my_data = "my data".to_string();
/// // This works:
/// session.start_transaction().and_run2(
/// async move |session| {
/// coll.insert_one(doc! { "data": my_data.clone() }).session(session).await
/// }
/// ).await?;
/// /* This will not compile:
/// session.start_transaction().and_run2(
/// async move |session| {
/// coll.insert_one(doc! { "data": my_data }).session(session).await
/// }
/// ).await?;
/// */
/// # Ok(())
/// # }
/// ```
#[rustversion::since(1.85)]
pub async fn and_run2<R>(
self,
mut callback: impl AsyncFnMut(&mut ClientSession) -> Result<R>,
) -> Result<R> {
convenient_run!(
self.session,
self.session
.start_transaction()
.with_options(self.options.clone())
.await?;
let ret = match callback(self.session, &mut context).await {
Ok(v) => v,
Err(e) => {
if matches!(
self.session.transaction.state,
TransactionState::Starting | TransactionState::InProgress
) {
self.session.abort_transaction().await?;
}
if e.contains_label(TRANSIENT_TRANSACTION_ERROR) && start.elapsed() < timeout {
continue 'transaction;
}
return Err(e);
}
};
if matches!(
self.session.transaction.state,
TransactionState::None
| TransactionState::Aborted
| TransactionState::Committed { .. }
) {
return Ok(ret);
}
'commit: loop {
match self.session.commit_transaction().await {
Ok(()) => return Ok(ret),
Err(e) => {
if e.is_max_time_ms_expired_error() || start.elapsed() >= timeout {
return Err(e);
}
if e.contains_label(UNKNOWN_TRANSACTION_COMMIT_RESULT) {
continue 'commit;
}
if e.contains_label(TRANSIENT_TRANSACTION_ERROR) {
continue 'transaction;
}
return Err(e);
}
}
}
}
.await,
callback(self.session).await,
self.session.abort_transaction().await,
self.session.commit_transaction().await,
)
}
}

Expand Down Expand Up @@ -238,57 +317,16 @@ impl StartTransaction<&mut crate::sync::ClientSession> {
where
F: for<'b> FnMut(&'b mut crate::sync::ClientSession) -> Result<R>,
{
let timeout = std::time::Duration::from_secs(120);
let start = std::time::Instant::now();

use crate::error::{TRANSIENT_TRANSACTION_ERROR, UNKNOWN_TRANSACTION_COMMIT_RESULT};

'transaction: loop {
convenient_run!(
self.session.async_client_session,
self.session
.start_transaction()
.with_options(self.options.clone())
.run()?;
let ret = match callback(self.session) {
Ok(v) => v,
Err(e) => {
if matches!(
self.session.async_client_session.transaction.state,
TransactionState::Starting | TransactionState::InProgress
) {
self.session.abort_transaction().run()?;
}
if e.contains_label(TRANSIENT_TRANSACTION_ERROR) && start.elapsed() < timeout {
continue 'transaction;
}
return Err(e);
}
};
if matches!(
self.session.async_client_session.transaction.state,
TransactionState::None
| TransactionState::Aborted
| TransactionState::Committed { .. }
) {
return Ok(ret);
}
'commit: loop {
match self.session.commit_transaction().run() {
Ok(()) => return Ok(ret),
Err(e) => {
if e.is_max_time_ms_expired_error() || start.elapsed() >= timeout {
return Err(e);
}
if e.contains_label(UNKNOWN_TRANSACTION_COMMIT_RESULT) {
continue 'commit;
}
if e.contains_label(TRANSIENT_TRANSACTION_ERROR) {
continue 'transaction;
}
return Err(e);
}
}
}
}
.run(),
callback(self.session),
self.session.abort_transaction().run(),
self.session.commit_transaction().run(),
)
}
}

Expand Down
8 changes: 2 additions & 6 deletions src/test/documentation_examples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1675,7 +1675,6 @@ async fn change_streams_examples() -> Result<()> {

async fn convenient_transaction_examples() -> Result<()> {
use crate::ClientSession;
use futures::FutureExt;
if !transactions_supported().await {
log_uncaptured(
"skipping convenient transaction API examples due to no transaction support",
Expand Down Expand Up @@ -1734,12 +1733,9 @@ async fn convenient_transaction_examples() -> Result<()> {
// Step 2: Start a client session.
let mut session = client.start_session().await?;

// Step 3: Use and_run to start a transaction, execute the callback, and commit (or
// Step 3: Use and_run2 to start a transaction, execute the callback, and commit (or
// abort on error).
session
.start_transaction()
.and_run((), |session, _| callback(session).boxed())
.await?;
session.start_transaction().and_run2(callback).await?;

// End Transactions withTxn API Example 1

Expand Down
50 changes: 17 additions & 33 deletions src/test/spec/transactions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::time::Duration;

use futures_util::FutureExt;
use serde::{Deserialize, Serialize};

use crate::{
Expand Down Expand Up @@ -104,12 +103,9 @@ async fn convenient_api_custom_error() {
struct MyErr;
let result: Result<()> = session
.start_transaction()
.and_run(coll, |session, coll| {
async move {
coll.find_one(doc! {}).session(session).await?;
Err(Error::custom(MyErr))
}
.boxed()
.and_run2(async move |session| {
coll.find_one(doc! {}).session(session).await?;
Err(Error::custom(MyErr))
})
.await;

Expand All @@ -136,12 +132,9 @@ async fn convenient_api_returned_value() {

let value = session
.start_transaction()
.and_run(coll, |session, coll| {
async move {
coll.find_one(doc! {}).session(session).await?;
Ok(42)
}
.boxed()
.and_run2(async move |session| {
coll.find_one(doc! {}).session(session).await?;
Ok(42)
})
.await
.unwrap();
Expand All @@ -165,14 +158,11 @@ async fn convenient_api_retry_timeout_callback() {

let result: Result<()> = session
.start_transaction()
.and_run(coll, |session, coll| {
async move {
coll.find_one(doc! {}).session(session).await?;
let mut err = Error::custom(42);
err.add_label(TRANSIENT_TRANSACTION_ERROR);
Err(err)
}
.boxed()
.and_run2(async move |session| {
coll.find_one(doc! {}).session(session).await?;
let mut err = Error::custom(42);
err.add_label(TRANSIENT_TRANSACTION_ERROR);
Err(err)
})
.await;

Expand Down Expand Up @@ -210,12 +200,9 @@ async fn convenient_api_retry_timeout_commit_unknown() {

let result = session
.start_transaction()
.and_run(coll, |session, coll| {
async move {
coll.find_one(doc! {}).session(session).await?;
Ok(())
}
.boxed()
.and_run2(async move |session| {
coll.find_one(doc! {}).session(session).await?;
Ok(())
})
.await;

Expand Down Expand Up @@ -252,12 +239,9 @@ async fn convenient_api_retry_timeout_commit_transient() {

let result = session
.start_transaction()
.and_run(coll, |session, coll| {
async move {
coll.find_one(doc! {}).session(session).await?;
Ok(())
}
.boxed()
.and_run2(async move |session| {
coll.find_one(doc! {}).session(session).await?;
Ok(())
})
.await;

Expand Down
Loading