Skip to content

Commit 7a562f0

Browse files
authored
Replace Tensor::buffer_mut (#127)
This refactors the data accessor functions in `Tensor` to be more consistent with conventions elsewhere (e.g., `get_`). It also checks a bit more robustly whether the underlying pointer can in fact be casted to the type we expect.
1 parent b6eacef commit 7a562f0

File tree

5 files changed

+77
-33
lines changed

5 files changed

+77
-33
lines changed

crates/openvino/src/prepostprocess.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
//! # let data = fs::read("tests/fixtures/inception/tensor-1x3x299x299-f32.bgr").expect("to read the tensor from file");
1414
//! # let input_shape = Shape::new(&vec![1, 299, 299, 3]).expect("to create a new shape");
1515
//! # let mut tensor = Tensor::new(ElementType::F32, &input_shape).expect("to create a new tensor");
16-
//! # let buffer = tensor.buffer_mut().unwrap();
16+
//! # let buffer = tensor.get_raw_data_mut().unwrap();
1717
//! # buffer.copy_from_slice(&data);
18-
//! // Insantiate a new core, read in a model, and set up a tensor with input data before performing pre/post processing
18+
//! // Instantiate a new core, read in a model, and set up a tensor with input data before performing pre/post processing
1919
//! // Pre-process the input by:
2020
//! // - converting NHWC to NCHW
2121
//! // - resizing the input image

crates/openvino/src/tensor.rs

Lines changed: 72 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -86,49 +86,79 @@ impl Tensor {
8686
Ok(byte_size)
8787
}
8888

89-
/// Get a mutable reference to the data of the tensor.
90-
pub fn get_data<T>(&mut self) -> Result<&mut [T]> {
91-
let mut data = std::ptr::null_mut();
92-
try_unsafe!(ov_tensor_data(self.ptr, std::ptr::addr_of_mut!(data),))?;
93-
let size = self.get_byte_size()? / std::mem::size_of::<T>();
94-
let slice = unsafe { std::slice::from_raw_parts_mut(data.cast::<T>(), size) };
89+
/// Get the underlying data for the tensor.
90+
pub fn get_raw_data(&self) -> Result<&[u8]> {
91+
let mut buffer = std::ptr::null_mut();
92+
try_unsafe!(ov_tensor_data(self.ptr, std::ptr::addr_of_mut!(buffer)))?;
93+
let size = self.get_byte_size()?;
94+
let slice = unsafe { std::slice::from_raw_parts(buffer.cast::<u8>(), size) };
9595
Ok(slice)
9696
}
9797

98-
/// Get a mutable reference to the buffer of the tensor.
99-
///
100-
/// # Returns
101-
///
102-
/// A mutable reference to the buffer of the tensor.
103-
pub fn buffer_mut(&mut self) -> Result<&mut [u8]> {
98+
/// Get a mutable reference to the underlying data for the tensor.
99+
pub fn get_raw_data_mut(&mut self) -> Result<&mut [u8]> {
104100
let mut buffer = std::ptr::null_mut();
105101
try_unsafe!(ov_tensor_data(self.ptr, std::ptr::addr_of_mut!(buffer)))?;
106102
let size = self.get_byte_size()?;
107103
let slice = unsafe { std::slice::from_raw_parts_mut(buffer.cast::<u8>(), size) };
108104
Ok(slice)
109105
}
106+
107+
/// Get a `T`-casted slice of the underlying data for the tensor.
108+
///
109+
/// # Panics
110+
///
111+
/// This method will panic if it can't cast the data to `T` due to the type size or the
112+
/// underlying pointer's alignment.
113+
pub fn get_data<T>(&self) -> Result<&[T]> {
114+
let raw_data = self.get_raw_data()?;
115+
let len = get_safe_len::<T>(raw_data);
116+
let slice = unsafe { std::slice::from_raw_parts(raw_data.as_ptr().cast::<T>(), len) };
117+
Ok(slice)
118+
}
119+
120+
/// Get a mutable `T`-casted slice of the underlying data for the tensor.
121+
///
122+
/// # Panics
123+
///
124+
/// This method will panic if it can't cast the data to `T` due to the type size or the
125+
/// underlying pointer's alignment.
126+
pub fn get_data_mut<T>(&mut self) -> Result<&mut [T]> {
127+
let raw_data = self.get_raw_data_mut()?;
128+
let len = get_safe_len::<T>(raw_data);
129+
let slice =
130+
unsafe { std::slice::from_raw_parts_mut(raw_data.as_mut_ptr().cast::<T>(), len) };
131+
Ok(slice)
132+
}
133+
}
134+
135+
/// Convenience function for checking that we can cast `data` to a slice of `T`, returning the
136+
/// length of that slice.
137+
fn get_safe_len<T>(data: &[u8]) -> usize {
138+
if data.len() % std::mem::size_of::<T>() != 0 {
139+
panic!("data size is not a multiple of the size of `T`");
140+
}
141+
if data.as_ptr() as usize % std::mem::align_of::<T>() != 0 {
142+
panic!("raw data is not aligned to `T`'s alignment");
143+
}
144+
data.len() / std::mem::size_of::<T>()
110145
}
111146

112147
#[cfg(test)]
113148
mod tests {
114149
use super::*;
115-
use crate::{ElementType, LoadingError, Shape};
116150

117151
#[test]
118152
fn test_create_tensor() {
119-
openvino_sys::library::load()
120-
.map_err(LoadingError::SystemFailure)
121-
.unwrap();
153+
openvino_sys::library::load().unwrap();
122154
let shape = Shape::new(&vec![1, 3, 227, 227]).unwrap();
123155
let tensor = Tensor::new(ElementType::F32, &shape).unwrap();
124156
assert!(!tensor.ptr.is_null());
125157
}
126158

127159
#[test]
128160
fn test_get_shape() {
129-
openvino_sys::library::load()
130-
.map_err(LoadingError::SystemFailure)
131-
.unwrap();
161+
openvino_sys::library::load().unwrap();
132162
let tensor = Tensor::new(
133163
ElementType::F32,
134164
&Shape::new(&vec![1, 3, 227, 227]).unwrap(),
@@ -140,9 +170,7 @@ mod tests {
140170

141171
#[test]
142172
fn test_get_element_type() {
143-
openvino_sys::library::load()
144-
.map_err(LoadingError::SystemFailure)
145-
.unwrap();
173+
openvino_sys::library::load().unwrap();
146174
let tensor = Tensor::new(
147175
ElementType::F32,
148176
&Shape::new(&vec![1, 3, 227, 227]).unwrap(),
@@ -154,9 +182,7 @@ mod tests {
154182

155183
#[test]
156184
fn test_get_size() {
157-
openvino_sys::library::load()
158-
.map_err(LoadingError::SystemFailure)
159-
.unwrap();
185+
openvino_sys::library::load().unwrap();
160186
let tensor = Tensor::new(
161187
ElementType::F32,
162188
&Shape::new(&vec![1, 3, 227, 227]).unwrap(),
@@ -168,9 +194,7 @@ mod tests {
168194

169195
#[test]
170196
fn test_get_byte_size() {
171-
openvino_sys::library::load()
172-
.map_err(LoadingError::SystemFailure)
173-
.unwrap();
197+
openvino_sys::library::load().unwrap();
174198
let tensor = Tensor::new(
175199
ElementType::F32,
176200
&Shape::new(&vec![1, 3, 227, 227]).unwrap(),
@@ -182,4 +206,24 @@ mod tests {
182206
1 * 3 * 227 * 227 * std::mem::size_of::<f32>() as usize
183207
);
184208
}
209+
210+
#[test]
211+
fn casting() {
212+
openvino_sys::library::load().unwrap();
213+
let shape = Shape::new(&vec![10, 10, 10]).unwrap();
214+
let tensor = Tensor::new(ElementType::F32, &shape).unwrap();
215+
let data = tensor.get_data::<f32>().unwrap();
216+
assert_eq!(data.len(), 10 * 10 * 10);
217+
}
218+
219+
#[test]
220+
#[should_panic(expected = "data size is not a multiple of the size of `T`")]
221+
fn casting_check() {
222+
openvino_sys::library::load().unwrap();
223+
let shape = Shape::new(&vec![10, 10, 10]).unwrap();
224+
let tensor = Tensor::new(ElementType::F32, &shape).unwrap();
225+
#[allow(dead_code)]
226+
struct LargeOddType([u8; 1061]);
227+
tensor.get_data::<LargeOddType>().unwrap();
228+
}
185229
}

crates/openvino/tests/classify-alexnet.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ fn classify_alexnet() -> anyhow::Result<()> {
2828
let input_shape = Shape::new(&vec![1, 227, 227, 3])?;
2929
let element_type = ElementType::F32;
3030
let mut tensor = Tensor::new(element_type, &input_shape)?;
31-
let buffer = tensor.buffer_mut()?;
31+
let buffer = tensor.get_raw_data_mut()?;
3232
buffer.copy_from_slice(&data);
3333

3434
// Pre-process the input by:

crates/openvino/tests/classify-inception.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ fn classify_inception() -> anyhow::Result<()> {
2828
let input_shape = Shape::new(&vec![1, 299, 299, 3])?;
2929
let element_type = ElementType::F32;
3030
let mut tensor = Tensor::new(element_type, &input_shape)?;
31-
let buffer = tensor.buffer_mut()?;
31+
let buffer = tensor.get_raw_data_mut()?;
3232
buffer.copy_from_slice(&data);
3333

3434
// Pre-process the input by:

crates/openvino/tests/classify-mobilenet.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ fn classify_mobilenet() -> anyhow::Result<()> {
2828
let input_shape = Shape::new(&vec![1, 224, 224, 3])?;
2929
let element_type = ElementType::F32;
3030
let mut tensor = Tensor::new(element_type, &input_shape)?;
31-
let buffer = tensor.buffer_mut()?;
31+
let buffer = tensor.get_raw_data_mut()?;
3232
buffer.copy_from_slice(&data);
3333

3434
// Pre-process the input by:

0 commit comments

Comments
 (0)