Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 2 additions & 15 deletions src/action/bulk_write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,20 +117,6 @@ where
}

async fn execute_inner(mut self) -> Result<R> {
#[cfg(feature = "in-use-encryption")]
if self.client.should_auto_encrypt().await {
use mongocrypt::error::{Error as EncryptionError, ErrorKind as EncryptionErrorKind};

let error = EncryptionError {
kind: EncryptionErrorKind::Client,
code: None,
message: Some(
"bulkWrite does not currently support automatic encryption".to_string(),
),
};
return Err(ErrorKind::Encryption(error).into());
}

resolve_write_concern_with_session!(
self.client,
self.options,
Expand All @@ -148,7 +134,8 @@ where
&self.models[total_attempted..],
total_attempted,
self.options.as_ref(),
);
)
.await;
let result = self
.client
.execute_operation::<BulkWriteOperation<R>>(
Expand Down
45 changes: 45 additions & 0 deletions src/bson_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ use crate::{
RawBsonRef,
RawDocumentBuf,
},
bson_compat::CStr,
checked::Checked,
cmap::Command,
error::{Error, ErrorKind, Result},
runtime::SyncLittleEndianRead,
};
Expand Down Expand Up @@ -246,6 +248,49 @@ pub(crate) fn get_or_prepend_id_field(doc: &mut RawDocumentBuf) -> Result<Bson>
}
}

/// A helper trait for working with collections of raw documents. This is useful for unifying
/// command-building implementations that conditionally construct either document sequences or a
/// single command document.
pub(crate) trait RawDocumentCollection: Default {
/// Calculates the total number of bytes that would be added to a collection of this type by the
/// given document.
fn bytes_added(index: usize, doc: &RawDocumentBuf) -> Result<usize>;

/// Adds the given document to the collection.
fn push(&mut self, doc: RawDocumentBuf);

/// Adds the collection of raw documents to the provided command.
fn add_to_command(self, identifier: &CStr, command: &mut Command);
}

impl RawDocumentCollection for Vec<RawDocumentBuf> {
fn bytes_added(_index: usize, doc: &RawDocumentBuf) -> Result<usize> {
Ok(doc.as_bytes().len())
}

fn push(&mut self, doc: RawDocumentBuf) {
self.push(doc);
}

fn add_to_command(self, identifier: &CStr, command: &mut Command) {
command.add_document_sequence(identifier, self);
}
}

impl RawDocumentCollection for RawArrayBuf {
fn bytes_added(index: usize, doc: &RawDocumentBuf) -> Result<usize> {
array_entry_size_bytes(index, doc.as_bytes().len())
}

fn push(&mut self, doc: RawDocumentBuf) {
self.push(doc);
}

fn add_to_command(self, identifier: &CStr, command: &mut Command) {
command.body.append(identifier, self);
}
}

#[cfg(test)]
mod test {
use crate::bson_util::num_decimal_digits;
Expand Down
3 changes: 2 additions & 1 deletion src/client/csfle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ impl ClientState {
.kms_providers(&opts.kms_providers.credentials_doc()?)?
.use_need_kms_credentials_state()
.retry_kms(true)?
.use_range_v2()?;
.use_range_v2()?
.use_need_mongo_collinfo_with_db_state();
if let Some(m) = &opts.schema_map {
builder = builder.schema_map(&crate::bson_compat::serialize_to_document(m)?)?;
}
Expand Down
33 changes: 18 additions & 15 deletions src/client/csfle/state_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@ use std::{
time::Duration,
};

use crate::{
bson::{rawdoc, Document, RawDocument, RawDocumentBuf},
bson_compat::{cstr, CString},
};
use crate::bson::{rawdoc, Document, RawDocument, RawDocumentBuf};
use futures_util::{stream, TryStreamExt};
use mongocrypt::ctx::{Ctx, KmsCtx, KmsProviderType, State};
use rayon::ThreadPool;
Expand Down Expand Up @@ -95,6 +92,13 @@ impl CryptExecutor {
self.mongocryptd_client.is_some()
}

fn metadata_client(&self, state: &State) -> Result<Client> {
self.metadata_client
.as_ref()
.and_then(|w| w.upgrade())
.ok_or_else(|| Error::internal(format!("metadata client required for {state:?}")))
}

pub(crate) async fn run_ctx(&self, ctx: Ctx, db: Option<&str>) -> Result<RawDocumentBuf> {
let mut result = None;
// This needs to be a `Result` so that the `Ctx` can be temporarily owned by the processing
Expand All @@ -104,16 +108,10 @@ impl CryptExecutor {
loop {
let state = result_ref(&ctx)?.state()?;
match state {
State::NeedMongoCollinfo => {
State::NeedMongoCollinfo | State::NeedMongoCollinfoWithDb => {
let ctx = result_mut(&mut ctx)?;
let filter = raw_to_doc(ctx.mongo_op()?)?;
let metadata_client = self
.metadata_client
.as_ref()
.and_then(|w| w.upgrade())
.ok_or_else(|| {
Error::internal("metadata_client required for NeedMongoCollinfo state")
})?;
let metadata_client = self.metadata_client(&state)?;
let db = metadata_client.database(db.as_ref().ok_or_else(|| {
Error::internal("db required for NeedMongoCollinfo state")
})?);
Expand Down Expand Up @@ -245,7 +243,9 @@ impl CryptExecutor {
continue;
}

let prov_name: CString = provider.as_string().try_into()?;
#[cfg(any(feature = "aws-auth", feature = "azure-kms"))]
let prov_name: crate::bson_compat::CString =
provider.as_string().try_into()?;
match provider.provider_type() {
KmsProviderType::Aws => {
#[cfg(feature = "aws-auth")]
Expand All @@ -263,7 +263,10 @@ impl CryptExecutor {
"secretAccessKey": aws_creds.secret_access_key().to_string(),
};
if let Some(token) = aws_creds.session_token() {
creds.append(cstr!("sessionToken"), token);
creds.append(
crate::bson_compat::cstr!("sessionToken"),
token,
);
}
kms_providers.append(prov_name, creds);
}
Expand Down Expand Up @@ -326,7 +329,7 @@ impl CryptExecutor {
.await
.map_err(|e| kms_error(e.to_string()))?;
kms_providers.append(
cstr!("gcp"),
crate::bson_compat::cstr!("gcp"),
rawdoc! { "accessToken": response.access_token },
);
}
Expand Down
50 changes: 30 additions & 20 deletions src/client/options/bulk_write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@ use typed_builder::TypedBuilder;

use crate::{
bson::{rawdoc, Array, Bson, Document, RawDocumentBuf},
bson_compat::cstr,
bson_util::{get_or_prepend_id_field, replacement_document_check, update_document_check},
error::Result,
bson_compat::{cstr, serialize_to_raw_document_buf},
bson_util::{
extend_raw_document_buf,
get_or_prepend_id_field,
replacement_document_check,
update_document_check,
},
error::{Error, Result},
options::{UpdateModifications, WriteConcern},
serde_util::{serialize_bool_or_true, write_concern_is_empty},
Collection,
Expand Down Expand Up @@ -371,9 +376,17 @@ impl WriteModel {
}
}

/// Returns the operation-specific fields that should be included in this model's entry in the
/// ops array. Also returns an inserted ID if this is an insert operation.
pub(crate) fn get_ops_document_contents(&self) -> Result<(RawDocumentBuf, Option<Bson>)> {
/// Constructs the ops document for this write model given the nsInfo array index.
pub(crate) fn get_ops_document(
&self,
ns_info_index: usize,
) -> Result<(RawDocumentBuf, Option<Bson>)> {
// The maximum number of namespaces allowed in a bulkWrite command is much lower than
// i32::MAX, so this should never fail.
let index = i32::try_from(ns_info_index)
.map_err(|_| Error::internal("nsInfo index exceeds i32::MAX"))?;
let mut ops_document = rawdoc! { self.operation_name(): index };

if let Self::UpdateOne(UpdateOneModel { update, .. })
| Self::UpdateMany(UpdateManyModel { update, .. }) = self
{
Expand All @@ -384,22 +397,19 @@ impl WriteModel {
replacement_document_check(replacement)?;
}

let (mut model_document, inserted_id) = match self {
Self::InsertOne(model) => {
let mut insert_document = RawDocumentBuf::try_from(&model.document)?;
let inserted_id = get_or_prepend_id_field(&mut insert_document)?;
(rawdoc! { "document": insert_document }, Some(inserted_id))
}
_ => {
let model_document = crate::bson_compat::serialize_to_raw_document_buf(&self)?;
(model_document, None)
}
};

if let Some(multi) = self.multi() {
model_document.append(cstr!("multi"), multi);
ops_document.append(cstr!("multi"), multi);
}

Ok((model_document, inserted_id))
if let Self::InsertOne(model) = self {
let mut insert_document = RawDocumentBuf::try_from(&model.document)?;
let inserted_id = get_or_prepend_id_field(&mut insert_document)?;
ops_document.append(cstr!("document"), insert_document);
Ok((ops_document, Some(inserted_id)))
} else {
let model = serialize_to_raw_document_buf(&self)?;
extend_raw_document_buf(&mut ops_document, model)?;
Ok((ops_document, None))
}
}
}
10 changes: 8 additions & 2 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,14 @@ pub type Result<T> = std::result::Result<T, Error>;
/// [`ErrorKind`](enum.ErrorKind.html) is wrapped in an `Box` to allow the errors to be
/// cloned.
#[derive(Clone, Debug, Error)]
#[cfg_attr(test, error("Kind: {kind}, labels: {labels:?}, backtrace: {bt}"))]
#[cfg_attr(not(test), error("Kind: {kind}, labels: {labels:?}"))]
#[cfg_attr(
test,
error("Kind: {kind}, labels: {labels:?}, source: {source:?}, backtrace: {bt}")
)]
#[cfg_attr(
not(test),
error("Kind: {kind}, labels: {labels:?}, source: {source:?}")
)]
#[non_exhaustive]
pub struct Error {
/// The type of error that occurred.
Expand Down
Loading