Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 35 additions & 17 deletions crates/openvino/src/blob.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use crate::tensor_desc::TensorDesc;
use crate::{drop_using_function, try_unsafe, util::Result, InferenceError};
use crate::{Layout, Precision};
use openvino_sys::{
self, dimensions_t, ie_blob_buffer__bindgen_ty_1, ie_blob_buffer_t, ie_blob_byte_size,
ie_blob_free, ie_blob_get_buffer, ie_blob_get_dims, ie_blob_get_layout, ie_blob_get_precision,
ie_blob_make_memory, ie_blob_size, ie_blob_t,
self, ie_blob_buffer__bindgen_ty_1, ie_blob_buffer_t, ie_blob_byte_size, ie_blob_free,
ie_blob_get_buffer, ie_blob_get_dims, ie_blob_get_layout, ie_blob_get_precision,
ie_blob_make_memory, ie_blob_size, ie_blob_t, tensor_desc_t,
};
use std::convert::TryFrom;
use std::mem::MaybeUninit;

/// See [`Blob`](https://docs.openvinotoolkit.org/latest/classInferenceEngine_1_1Blob.html).
pub struct Blob {
Expand Down Expand Up @@ -53,22 +53,25 @@ impl Blob {
pub fn tensor_desc(&self) -> Result<TensorDesc> {
let blob = self.instance as *const ie_blob_t;

let mut layout = Layout::ANY;
try_unsafe!(ie_blob_get_layout(blob, std::ptr::addr_of_mut!(layout)))?;
let mut layout = MaybeUninit::uninit();
try_unsafe!(ie_blob_get_layout(blob, layout.as_mut_ptr()))?;

let mut dimensions = dimensions_t {
ranks: 0,
dims: [0; 8usize],
};
try_unsafe!(ie_blob_get_dims(blob, std::ptr::addr_of_mut!(dimensions)))?;
let mut dimensions = MaybeUninit::uninit();
try_unsafe!(ie_blob_get_dims(blob, dimensions.as_mut_ptr()))?;

let mut precision = Precision::UNSPECIFIED;
try_unsafe!(ie_blob_get_precision(
blob,
std::ptr::addr_of_mut!(precision)
))?;
let mut precision = MaybeUninit::uninit();
try_unsafe!(ie_blob_get_precision(blob, precision.as_mut_ptr()))?;

Ok(TensorDesc::new(layout, &dimensions.dims, precision))
Ok(TensorDesc {
// Safety: all reads succeeded so values must be initialized
instance: unsafe {
tensor_desc_t {
layout: layout.assume_init(),
dims: dimensions.assume_init(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'm fine with the MaybeUninit changes but can we add a double-check assertion that dimension.dims is an array of size 8? crates/openvino-sys/src/generated/types.rs has a definition like the following:

pub struct dimensions {
    pub ranks: usize,
    pub dims: [usize; 8usize],
}

And if for some reason some version of OpenVINO changed this, I would want to know with a panic, not some undefined behavior. See TensorDesc::new() for how this looks there.

Copy link
Contributor Author

@chemicstry chemicstry Dec 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand what you mean. MaybeUninit type is automatically inferred from dimensions_t (aka dimensions). So in this case it is MaybeUninit<dimensions_t> and if dimensions were changed, the type of MaybeUninit would change accordingly.

Or do you want to ensure that when bindings are regenerated dimensions_t dims remain a size of 8? In that case I believe a compile time assertion in bindings would be better than a runtime check.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately it is probably too complicated to do so in the bindings since they are auto-generated and that still doesn't solve the undefined-ness of potentially storing a wrongly-sized array in that field. It's a bit paranoid of me, I guess (why would OpenVINO itself return the wrong size anyways!?), but I would feel better if there is some check there. I'll merge this as-is and add something in a follow-on commit.

precision: precision.assume_init(),
}
},
})
}

/// Get the number of elements contained in the [`Blob`].
Expand Down Expand Up @@ -186,6 +189,7 @@ impl Blob {
#[cfg(test)]
mod tests {
use super::*;
use crate::{Layout, Precision};

#[test]
#[should_panic]
Expand Down Expand Up @@ -223,4 +227,18 @@ mod tests {
"we should have half as many items (u16 = f32 / 2)"
);
}

#[test]
fn tensor_desc() {
openvino_sys::library::load().expect("unable to find an OpenVINO shared library");

let desc = TensorDesc::new(Layout::NHWC, &[1, 2, 2, 2], Precision::U8);
let blob = Blob::new(&desc, &[0; 8]).unwrap();
let desc2 = blob.tensor_desc().unwrap();

// Both TensorDesc's should be equal
assert_eq!(desc.layout(), desc2.layout());
assert_eq!(desc.dims(), desc2.dims());
assert_eq!(desc.precision(), desc2.precision());
}
}