Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 124 additions & 3 deletions crates/openvino/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,20 @@

use crate::node::Node;
use crate::request::InferRequest;
use crate::{drop_using_function, try_unsafe, util::Result};
use crate::{cstr, drop_using_function, try_unsafe, util::Result, PropertyKey, RwPropertyKey};
use openvino_sys::{
ov_compiled_model_create_infer_request, ov_compiled_model_free, ov_compiled_model_t,
ov_compiled_model_create_infer_request, ov_compiled_model_free, ov_compiled_model_get_property,
ov_compiled_model_get_runtime_model, ov_compiled_model_input, ov_compiled_model_input_by_index,
ov_compiled_model_input_by_name, ov_compiled_model_inputs_size, ov_compiled_model_output,
ov_compiled_model_output_by_index, ov_compiled_model_output_by_name,
ov_compiled_model_outputs_size, ov_compiled_model_set_property, ov_compiled_model_t,
ov_model_const_input_by_index, ov_model_const_output_by_index, ov_model_free,
ov_model_inputs_size, ov_model_is_dynamic, ov_model_outputs_size, ov_model_t,
};
use std::borrow::Cow;
use std::ffi::CStr;

/// See [`Model`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__model__c__api.html).
/// See [`Model`](https://docs.openvino.ai/2024/api/c_cpp_api/group__ov__model__c__api.html).
pub struct Model {
ptr: *mut ov_model_t,
}
Expand Down Expand Up @@ -107,4 +113,119 @@ impl CompiledModel {
))?;
Ok(InferRequest::from_ptr(infer_request))
}

/// Get the number of inputs of the compiled model.
pub fn get_input_size(&self) -> Result<usize> {
let mut input_size: usize = 0;
try_unsafe!(ov_compiled_model_inputs_size(self.ptr, &mut input_size))?;
Ok(input_size)
}

/// Get the single input port of the compiled model, which only support single input model.
pub fn get_input(&self) -> Result<Node> {
let mut port = std::ptr::null_mut();
try_unsafe!(ov_compiled_model_input(
self.ptr,
std::ptr::addr_of_mut!(port)
))?;
Ok(Node::new(port))
}

/// Get an input port of the compiled model by port index.
pub fn get_input_by_index(&self, index: usize) -> Result<Node> {
let mut port = std::ptr::null_mut();
try_unsafe!(ov_compiled_model_input_by_index(
self.ptr,
index,
std::ptr::addr_of_mut!(port)
))?;
Ok(Node::new(port))
}

/// Get an input port of the compiled model by name.
pub fn get_input_by_name(&self, name: &str) -> Result<Node> {
let name = cstr!(name);
let mut port = std::ptr::null_mut();
try_unsafe!(ov_compiled_model_input_by_name(
self.ptr,
name,
std::ptr::addr_of_mut!(port)
))?;
Ok(Node::new(port))
}

/// Get the number of outputs of the compiled model.
pub fn get_output_size(&self) -> Result<usize> {
let mut output_size: usize = 0;
try_unsafe!(ov_compiled_model_outputs_size(self.ptr, &mut output_size))?;
Ok(output_size)
}

/// Get the single output port of the compiled model, which only support single output model.
pub fn get_output(&self) -> Result<Node> {
let mut port = std::ptr::null_mut();
try_unsafe!(ov_compiled_model_output(
self.ptr,
std::ptr::addr_of_mut!(port)
))?;
Ok(Node::new(port))
}

/// Get an output port of the compiled model by port index.
pub fn get_output_by_index(&self, index: usize) -> Result<Node> {
let mut port = std::ptr::null_mut();
try_unsafe!(ov_compiled_model_output_by_index(
self.ptr,
index,
std::ptr::addr_of_mut!(port)
))?;
Ok(Node::new(port))
}

/// Get an output port of the compiled model by name.
pub fn get_output_by_name(&self, name: &str) -> Result<Node> {
let name = cstr!(name);
let mut port = std::ptr::null_mut();
try_unsafe!(ov_compiled_model_output_by_name(
self.ptr,
name,
std::ptr::addr_of_mut!(port)
))?;
Ok(Node::new(port))
}

/// Gets runtime model information from a device.
pub fn get_runtime_model(&self) -> Result<Model> {
let mut ptr = std::ptr::null_mut();
try_unsafe!(ov_compiled_model_get_runtime_model(
self.ptr,
std::ptr::addr_of_mut!(ptr)
))?;
Ok(Model { ptr })
}

/// Gets a property for the compiled model.
pub fn get_property(&self, key: PropertyKey) -> Result<Cow<str>> {
let ov_prop_key = cstr!(key.as_ref());
let mut ov_prop_value = std::ptr::null_mut();
try_unsafe!(ov_compiled_model_get_property(
self.ptr,
ov_prop_key,
std::ptr::addr_of_mut!(ov_prop_value)
))?;
let rust_prop = unsafe { CStr::from_ptr(ov_prop_value) }.to_string_lossy();
Ok(rust_prop)
}

/// Sets a property for the compiled model.
pub fn set_property(&mut self, key: RwPropertyKey, value: &str) -> Result<()> {
let ov_prop_key = cstr!(key.as_ref());
let ov_prop_value = cstr!(value);
try_unsafe!(ov_compiled_model_set_property(
self.ptr,
ov_prop_key,
ov_prop_value,
))?;
Ok(())
}
}