Skip to content

Commit e5e2cc1

Browse files
authored
Fix Blob::tensor_desc() (#56)
* Fix and optimise Blob::tensor_desc() * Load shared library for Blob::tensor_desc() test
1 parent cb7fa2a commit e5e2cc1

File tree

1 file changed

+35
-17
lines changed

1 file changed

+35
-17
lines changed

crates/openvino/src/blob.rs

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
use crate::tensor_desc::TensorDesc;
22
use crate::{drop_using_function, try_unsafe, util::Result, InferenceError};
3-
use crate::{Layout, Precision};
43
use openvino_sys::{
5-
self, dimensions_t, ie_blob_buffer__bindgen_ty_1, ie_blob_buffer_t, ie_blob_byte_size,
6-
ie_blob_free, ie_blob_get_buffer, ie_blob_get_dims, ie_blob_get_layout, ie_blob_get_precision,
7-
ie_blob_make_memory, ie_blob_size, ie_blob_t,
4+
self, ie_blob_buffer__bindgen_ty_1, ie_blob_buffer_t, ie_blob_byte_size, ie_blob_free,
5+
ie_blob_get_buffer, ie_blob_get_dims, ie_blob_get_layout, ie_blob_get_precision,
6+
ie_blob_make_memory, ie_blob_size, ie_blob_t, tensor_desc_t,
87
};
98
use std::convert::TryFrom;
9+
use std::mem::MaybeUninit;
1010

1111
/// See [`Blob`](https://docs.openvinotoolkit.org/latest/classInferenceEngine_1_1Blob.html).
1212
pub struct Blob {
@@ -53,22 +53,25 @@ impl Blob {
5353
pub fn tensor_desc(&self) -> Result<TensorDesc> {
5454
let blob = self.instance as *const ie_blob_t;
5555

56-
let mut layout = Layout::ANY;
57-
try_unsafe!(ie_blob_get_layout(blob, std::ptr::addr_of_mut!(layout)))?;
56+
let mut layout = MaybeUninit::uninit();
57+
try_unsafe!(ie_blob_get_layout(blob, layout.as_mut_ptr()))?;
5858

59-
let mut dimensions = dimensions_t {
60-
ranks: 0,
61-
dims: [0; 8usize],
62-
};
63-
try_unsafe!(ie_blob_get_dims(blob, std::ptr::addr_of_mut!(dimensions)))?;
59+
let mut dimensions = MaybeUninit::uninit();
60+
try_unsafe!(ie_blob_get_dims(blob, dimensions.as_mut_ptr()))?;
6461

65-
let mut precision = Precision::UNSPECIFIED;
66-
try_unsafe!(ie_blob_get_precision(
67-
blob,
68-
std::ptr::addr_of_mut!(precision)
69-
))?;
62+
let mut precision = MaybeUninit::uninit();
63+
try_unsafe!(ie_blob_get_precision(blob, precision.as_mut_ptr()))?;
7064

71-
Ok(TensorDesc::new(layout, &dimensions.dims, precision))
65+
Ok(TensorDesc {
66+
// Safety: all reads succeeded so values must be initialized
67+
instance: unsafe {
68+
tensor_desc_t {
69+
layout: layout.assume_init(),
70+
dims: dimensions.assume_init(),
71+
precision: precision.assume_init(),
72+
}
73+
},
74+
})
7275
}
7376

7477
/// Get the number of elements contained in the [`Blob`].
@@ -186,6 +189,7 @@ impl Blob {
186189
#[cfg(test)]
187190
mod tests {
188191
use super::*;
192+
use crate::{Layout, Precision};
189193

190194
#[test]
191195
#[should_panic]
@@ -223,4 +227,18 @@ mod tests {
223227
"we should have half as many items (u16 = f32 / 2)"
224228
);
225229
}
230+
231+
#[test]
232+
fn tensor_desc() {
233+
openvino_sys::library::load().expect("unable to find an OpenVINO shared library");
234+
235+
let desc = TensorDesc::new(Layout::NHWC, &[1, 2, 2, 2], Precision::U8);
236+
let blob = Blob::new(&desc, &[0; 8]).unwrap();
237+
let desc2 = blob.tensor_desc().unwrap();
238+
239+
// Both TensorDesc's should be equal
240+
assert_eq!(desc.layout(), desc2.layout());
241+
assert_eq!(desc.dims(), desc2.dims());
242+
assert_eq!(desc.precision(), desc2.precision());
243+
}
226244
}

0 commit comments

Comments
 (0)