diff --git a/src/cmap/conn.rs b/src/cmap/conn.rs index 85b9ade9d..d7c39cca7 100644 --- a/src/cmap/conn.rs +++ b/src/cmap/conn.rs @@ -222,6 +222,7 @@ impl Connection { self.command_executing = true; + let max_message_size = self.max_message_size_bytes(); #[cfg(any( feature = "zstd-compression", feature = "zlib-compression", @@ -230,30 +231,30 @@ impl Connection { let write_result = match self.compressor { Some(ref compressor) if message.should_compress => { message - .write_op_compressed_to(&mut self.stream, compressor) + .write_op_compressed_to(&mut self.stream, compressor, max_message_size) + .await + } + _ => { + message + .write_op_msg_to(&mut self.stream, max_message_size) .await } - _ => message.write_op_msg_to(&mut self.stream).await, }; #[cfg(all( not(feature = "zstd-compression"), not(feature = "zlib-compression"), not(feature = "snappy-compression") ))] - let write_result = message.write_op_msg_to(&mut self.stream).await; + let write_result = message + .write_op_msg_to(&mut self.stream, max_message_size) + .await; if let Err(ref err) = write_result { self.error = Some(err.clone()); } write_result?; - let response_message_result = Message::read_from( - &mut self.stream, - self.stream_description - .as_ref() - .map(|d| d.max_message_size_bytes), - ) - .await; + let response_message_result = Message::read_from(&mut self.stream, max_message_size).await; self.command_executing = false; if let Err(ref err) = response_message_result { self.error = Some(err.clone()); @@ -306,6 +307,12 @@ impl Connection { pub(crate) fn is_streaming(&self) -> bool { self.more_to_come } + + fn max_message_size_bytes(&self) -> Option { + self.stream_description + .as_ref() + .map(|d| d.max_message_size_bytes) + } } /// A handle to a pinned connection - the connection itself can be retrieved or returned to the diff --git a/src/cmap/conn/wire/message.rs b/src/cmap/conn/wire/message.rs index c746c8b95..fcece7a1f 100644 --- a/src/cmap/conn/wire/message.rs +++ b/src/cmap/conn/wire/message.rs @@ -274,6 +274,7 @@ impl Message { pub(crate) async fn write_op_msg_to( &self, mut writer: T, + max_message_size_bytes: Option, ) -> Result<()> { let sections = self.get_sections_bytes()?; @@ -286,6 +287,15 @@ impl Message { .map(std::mem::size_of_val) .unwrap_or(0); + let max_len = + Checked::try_from(max_message_size_bytes.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE_BYTES))?; + if total_length > max_len { + return Err(ErrorKind::InvalidArgument { + message: format!("Message length {} over maximum {}", total_length, max_len), + } + .into()); + } + let header = Header { length: total_length.try_into()?, request_id: self.request_id.unwrap_or_else(next_request_id), @@ -316,6 +326,7 @@ impl Message { &self, mut writer: T, compressor: &Compressor, + max_message_size_bytes: Option, ) -> Result<()> { let flag_bytes = &self.flags.bits().to_le_bytes(); let section_bytes = self.get_sections_bytes()?; @@ -329,6 +340,15 @@ impl Message { + std::mem::size_of::() + compressed_bytes.len(); + let max_len = + Checked::try_from(max_message_size_bytes.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE_BYTES))?; + if total_length > max_len { + return Err(ErrorKind::InvalidArgument { + message: format!("Message length {} over maximum {}", total_length, max_len), + } + .into()); + } + let header = Header { length: total_length.try_into()?, request_id: self.request_id.unwrap_or_else(next_request_id),