1
1
use crate :: tensor:: Tensor ;
2
2
use crate :: { cstr, drop_using_function, try_unsafe, util:: Result } ;
3
3
use openvino_sys:: {
4
- ov_infer_request_free, ov_infer_request_get_tensor, ov_infer_request_infer,
5
- ov_infer_request_set_tensor, ov_infer_request_start_async, ov_infer_request_t,
6
- ov_infer_request_wait_for,
4
+ ov_infer_request_free, ov_infer_request_get_output_tensor_by_index,
5
+ ov_infer_request_get_tensor, ov_infer_request_infer,
6
+ ov_infer_request_set_input_tensor_by_index, ov_infer_request_set_tensor,
7
+ ov_infer_request_start_async, ov_infer_request_t, ov_infer_request_wait_for,
7
8
} ;
8
9
9
10
/// See [`InferRequest`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__infer__request__c__api.html).
@@ -20,6 +21,7 @@ impl InferRequest {
20
21
pub ( crate ) fn from_ptr ( ptr : * mut ov_infer_request_t ) -> Self {
21
22
Self { ptr }
22
23
}
24
+
23
25
/// Assign a [`Tensor`] to the input on the model.
24
26
pub fn set_tensor ( & mut self , name : & str , tensor : & Tensor ) -> Result < ( ) > {
25
27
try_unsafe ! ( ov_infer_request_set_tensor(
@@ -41,6 +43,27 @@ impl InferRequest {
41
43
Ok ( Tensor :: from_ptr ( tensor) )
42
44
}
43
45
46
+ /// Assing an input [`Tensor`] to the model by its index.
47
+ pub fn set_input_tensor_by_index ( & mut self , index : usize , tensor : & Tensor ) -> Result < ( ) > {
48
+ try_unsafe ! ( ov_infer_request_set_input_tensor_by_index(
49
+ self . ptr,
50
+ index,
51
+ tensor. as_ptr( )
52
+ ) ) ?;
53
+ Ok ( ( ) )
54
+ }
55
+
56
+ /// Retrieve an output [`Tensor`] from the model by its index.
57
+ pub fn get_output_tensor_by_index ( & self , index : usize ) -> Result < Tensor > {
58
+ let mut tensor = std:: ptr:: null_mut ( ) ;
59
+ try_unsafe ! ( ov_infer_request_get_output_tensor_by_index(
60
+ self . ptr,
61
+ index,
62
+ std:: ptr:: addr_of_mut!( tensor)
63
+ ) ) ?;
64
+ Ok ( Tensor :: from_ptr ( tensor) )
65
+ }
66
+
44
67
/// Execute the inference request.
45
68
pub fn infer ( & mut self ) -> Result < ( ) > {
46
69
try_unsafe ! ( ov_infer_request_infer( self . ptr) )
0 commit comments