diff --git a/crates/openvino/src/blob.rs b/crates/openvino/src/blob.rs index 20b060c..5372759 100644 --- a/crates/openvino/src/blob.rs +++ b/crates/openvino/src/blob.rs @@ -1,12 +1,12 @@ use crate::tensor_desc::TensorDesc; use crate::{drop_using_function, try_unsafe, util::Result, InferenceError}; -use crate::{Layout, Precision}; use openvino_sys::{ - self, dimensions_t, ie_blob_buffer__bindgen_ty_1, ie_blob_buffer_t, ie_blob_byte_size, - ie_blob_free, ie_blob_get_buffer, ie_blob_get_dims, ie_blob_get_layout, ie_blob_get_precision, - ie_blob_make_memory, ie_blob_size, ie_blob_t, + self, ie_blob_buffer__bindgen_ty_1, ie_blob_buffer_t, ie_blob_byte_size, ie_blob_free, + ie_blob_get_buffer, ie_blob_get_dims, ie_blob_get_layout, ie_blob_get_precision, + ie_blob_make_memory, ie_blob_size, ie_blob_t, tensor_desc_t, }; use std::convert::TryFrom; +use std::mem::MaybeUninit; /// See [`Blob`](https://docs.openvinotoolkit.org/latest/classInferenceEngine_1_1Blob.html). pub struct Blob { @@ -53,22 +53,25 @@ impl Blob { pub fn tensor_desc(&self) -> Result { let blob = self.instance as *const ie_blob_t; - let mut layout = Layout::ANY; - try_unsafe!(ie_blob_get_layout(blob, std::ptr::addr_of_mut!(layout)))?; + let mut layout = MaybeUninit::uninit(); + try_unsafe!(ie_blob_get_layout(blob, layout.as_mut_ptr()))?; - let mut dimensions = dimensions_t { - ranks: 0, - dims: [0; 8usize], - }; - try_unsafe!(ie_blob_get_dims(blob, std::ptr::addr_of_mut!(dimensions)))?; + let mut dimensions = MaybeUninit::uninit(); + try_unsafe!(ie_blob_get_dims(blob, dimensions.as_mut_ptr()))?; - let mut precision = Precision::UNSPECIFIED; - try_unsafe!(ie_blob_get_precision( - blob, - std::ptr::addr_of_mut!(precision) - ))?; + let mut precision = MaybeUninit::uninit(); + try_unsafe!(ie_blob_get_precision(blob, precision.as_mut_ptr()))?; - Ok(TensorDesc::new(layout, &dimensions.dims, precision)) + Ok(TensorDesc { + // Safety: all reads succeeded so values must be initialized + instance: unsafe { + tensor_desc_t { + layout: layout.assume_init(), + dims: dimensions.assume_init(), + precision: precision.assume_init(), + } + }, + }) } /// Get the number of elements contained in the [`Blob`]. @@ -186,6 +189,7 @@ impl Blob { #[cfg(test)] mod tests { use super::*; + use crate::{Layout, Precision}; #[test] #[should_panic] @@ -223,4 +227,18 @@ mod tests { "we should have half as many items (u16 = f32 / 2)" ); } + + #[test] + fn tensor_desc() { + openvino_sys::library::load().expect("unable to find an OpenVINO shared library"); + + let desc = TensorDesc::new(Layout::NHWC, &[1, 2, 2, 2], Precision::U8); + let blob = Blob::new(&desc, &[0; 8]).unwrap(); + let desc2 = blob.tensor_desc().unwrap(); + + // Both TensorDesc's should be equal + assert_eq!(desc.layout(), desc2.layout()); + assert_eq!(desc.dims(), desc2.dims()); + assert_eq!(desc.precision(), desc2.precision()); + } }