|
1 | 1 | use crate::tensor_desc::TensorDesc;
|
2 | 2 | use crate::{drop_using_function, try_unsafe, util::Result, InferenceError};
|
3 |
| -use crate::{Layout, Precision}; |
4 | 3 | use openvino_sys::{
|
5 |
| - self, dimensions_t, ie_blob_buffer__bindgen_ty_1, ie_blob_buffer_t, ie_blob_byte_size, |
6 |
| - ie_blob_free, ie_blob_get_buffer, ie_blob_get_dims, ie_blob_get_layout, ie_blob_get_precision, |
7 |
| - ie_blob_make_memory, ie_blob_size, ie_blob_t, |
| 4 | + self, ie_blob_buffer__bindgen_ty_1, ie_blob_buffer_t, ie_blob_byte_size, ie_blob_free, |
| 5 | + ie_blob_get_buffer, ie_blob_get_dims, ie_blob_get_layout, ie_blob_get_precision, |
| 6 | + ie_blob_make_memory, ie_blob_size, ie_blob_t, tensor_desc_t, |
8 | 7 | };
|
9 | 8 | use std::convert::TryFrom;
|
| 9 | +use std::mem::MaybeUninit; |
10 | 10 |
|
11 | 11 | /// See [`Blob`](https://docs.openvinotoolkit.org/latest/classInferenceEngine_1_1Blob.html).
|
12 | 12 | pub struct Blob {
|
@@ -53,22 +53,25 @@ impl Blob {
|
53 | 53 | pub fn tensor_desc(&self) -> Result<TensorDesc> {
|
54 | 54 | let blob = self.instance as *const ie_blob_t;
|
55 | 55 |
|
56 |
| - let mut layout = Layout::ANY; |
57 |
| - try_unsafe!(ie_blob_get_layout(blob, std::ptr::addr_of_mut!(layout)))?; |
| 56 | + let mut layout = MaybeUninit::uninit(); |
| 57 | + try_unsafe!(ie_blob_get_layout(blob, layout.as_mut_ptr()))?; |
58 | 58 |
|
59 |
| - let mut dimensions = dimensions_t { |
60 |
| - ranks: 0, |
61 |
| - dims: [0; 8usize], |
62 |
| - }; |
63 |
| - try_unsafe!(ie_blob_get_dims(blob, std::ptr::addr_of_mut!(dimensions)))?; |
| 59 | + let mut dimensions = MaybeUninit::uninit(); |
| 60 | + try_unsafe!(ie_blob_get_dims(blob, dimensions.as_mut_ptr()))?; |
64 | 61 |
|
65 |
| - let mut precision = Precision::UNSPECIFIED; |
66 |
| - try_unsafe!(ie_blob_get_precision( |
67 |
| - blob, |
68 |
| - std::ptr::addr_of_mut!(precision) |
69 |
| - ))?; |
| 62 | + let mut precision = MaybeUninit::uninit(); |
| 63 | + try_unsafe!(ie_blob_get_precision(blob, precision.as_mut_ptr()))?; |
70 | 64 |
|
71 |
| - Ok(TensorDesc::new(layout, &dimensions.dims, precision)) |
| 65 | + Ok(TensorDesc { |
| 66 | + // Safety: all reads succeeded so values must be initialized |
| 67 | + instance: unsafe { |
| 68 | + tensor_desc_t { |
| 69 | + layout: layout.assume_init(), |
| 70 | + dims: dimensions.assume_init(), |
| 71 | + precision: precision.assume_init(), |
| 72 | + } |
| 73 | + }, |
| 74 | + }) |
72 | 75 | }
|
73 | 76 |
|
74 | 77 | /// Get the number of elements contained in the [`Blob`].
|
@@ -186,6 +189,7 @@ impl Blob {
|
186 | 189 | #[cfg(test)]
|
187 | 190 | mod tests {
|
188 | 191 | use super::*;
|
| 192 | + use crate::{Layout, Precision}; |
189 | 193 |
|
190 | 194 | #[test]
|
191 | 195 | #[should_panic]
|
@@ -223,4 +227,18 @@ mod tests {
|
223 | 227 | "we should have half as many items (u16 = f32 / 2)"
|
224 | 228 | );
|
225 | 229 | }
|
| 230 | + |
| 231 | + #[test] |
| 232 | + fn tensor_desc() { |
| 233 | + openvino_sys::library::load().expect("unable to find an OpenVINO shared library"); |
| 234 | + |
| 235 | + let desc = TensorDesc::new(Layout::NHWC, &[1, 2, 2, 2], Precision::U8); |
| 236 | + let blob = Blob::new(&desc, &[0; 8]).unwrap(); |
| 237 | + let desc2 = blob.tensor_desc().unwrap(); |
| 238 | + |
| 239 | + // Both TensorDesc's should be equal |
| 240 | + assert_eq!(desc.layout(), desc2.layout()); |
| 241 | + assert_eq!(desc.dims(), desc2.dims()); |
| 242 | + assert_eq!(desc.precision(), desc2.precision()); |
| 243 | + } |
226 | 244 | }
|
0 commit comments