Skip to content

Commit 3a404ee

Browse files
authored
Expose setting intput/getting output tensors on the model (#106)
Expose setting input/getting output tensors by index
1 parent 836dd87 commit 3a404ee

File tree

1 file changed

+26
-3
lines changed

1 file changed

+26
-3
lines changed

crates/openvino/src/request.rs

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
use crate::tensor::Tensor;
22
use crate::{cstr, drop_using_function, try_unsafe, util::Result};
33
use openvino_sys::{
4-
ov_infer_request_free, ov_infer_request_get_tensor, ov_infer_request_infer,
5-
ov_infer_request_set_tensor, ov_infer_request_start_async, ov_infer_request_t,
6-
ov_infer_request_wait_for,
4+
ov_infer_request_free, ov_infer_request_get_output_tensor_by_index,
5+
ov_infer_request_get_tensor, ov_infer_request_infer,
6+
ov_infer_request_set_input_tensor_by_index, ov_infer_request_set_tensor,
7+
ov_infer_request_start_async, ov_infer_request_t, ov_infer_request_wait_for,
78
};
89

910
/// See [`InferRequest`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__infer__request__c__api.html).
@@ -20,6 +21,7 @@ impl InferRequest {
2021
pub(crate) fn from_ptr(ptr: *mut ov_infer_request_t) -> Self {
2122
Self { ptr }
2223
}
24+
2325
/// Assign a [`Tensor`] to the input on the model.
2426
pub fn set_tensor(&mut self, name: &str, tensor: &Tensor) -> Result<()> {
2527
try_unsafe!(ov_infer_request_set_tensor(
@@ -41,6 +43,27 @@ impl InferRequest {
4143
Ok(Tensor::from_ptr(tensor))
4244
}
4345

46+
/// Assing an input [`Tensor`] to the model by its index.
47+
pub fn set_input_tensor_by_index(&mut self, index: usize, tensor: &Tensor) -> Result<()> {
48+
try_unsafe!(ov_infer_request_set_input_tensor_by_index(
49+
self.ptr,
50+
index,
51+
tensor.as_ptr()
52+
))?;
53+
Ok(())
54+
}
55+
56+
/// Retrieve an output [`Tensor`] from the model by its index.
57+
pub fn get_output_tensor_by_index(&self, index: usize) -> Result<Tensor> {
58+
let mut tensor = std::ptr::null_mut();
59+
try_unsafe!(ov_infer_request_get_output_tensor_by_index(
60+
self.ptr,
61+
index,
62+
std::ptr::addr_of_mut!(tensor)
63+
))?;
64+
Ok(Tensor::from_ptr(tensor))
65+
}
66+
4467
/// Execute the inference request.
4568
pub fn infer(&mut self) -> Result<()> {
4669
try_unsafe!(ov_infer_request_infer(self.ptr))

0 commit comments

Comments
 (0)