@@ -86,49 +86,79 @@ impl Tensor {
86
86
Ok ( byte_size)
87
87
}
88
88
89
- /// Get a mutable reference to the data of the tensor.
90
- pub fn get_data < T > ( & mut self ) -> Result < & mut [ T ] > {
91
- let mut data = std:: ptr:: null_mut ( ) ;
92
- try_unsafe ! ( ov_tensor_data( self . ptr, std:: ptr:: addr_of_mut!( data ) , ) ) ?;
93
- let size = self . get_byte_size ( ) ? / std :: mem :: size_of :: < T > ( ) ;
94
- let slice = unsafe { std:: slice:: from_raw_parts_mut ( data . cast :: < T > ( ) , size) } ;
89
+ /// Get the underlying data for the tensor.
90
+ pub fn get_raw_data ( & self ) -> Result < & [ u8 ] > {
91
+ let mut buffer = std:: ptr:: null_mut ( ) ;
92
+ try_unsafe ! ( ov_tensor_data( self . ptr, std:: ptr:: addr_of_mut!( buffer ) ) ) ?;
93
+ let size = self . get_byte_size ( ) ?;
94
+ let slice = unsafe { std:: slice:: from_raw_parts ( buffer . cast :: < u8 > ( ) , size) } ;
95
95
Ok ( slice)
96
96
}
97
97
98
- /// Get a mutable reference to the buffer of the tensor.
99
- ///
100
- /// # Returns
101
- ///
102
- /// A mutable reference to the buffer of the tensor.
103
- pub fn buffer_mut ( & mut self ) -> Result < & mut [ u8 ] > {
98
+ /// Get a mutable reference to the underlying data for the tensor.
99
+ pub fn get_raw_data_mut ( & mut self ) -> Result < & mut [ u8 ] > {
104
100
let mut buffer = std:: ptr:: null_mut ( ) ;
105
101
try_unsafe ! ( ov_tensor_data( self . ptr, std:: ptr:: addr_of_mut!( buffer) ) ) ?;
106
102
let size = self . get_byte_size ( ) ?;
107
103
let slice = unsafe { std:: slice:: from_raw_parts_mut ( buffer. cast :: < u8 > ( ) , size) } ;
108
104
Ok ( slice)
109
105
}
106
+
107
+ /// Get a `T`-casted slice of the underlying data for the tensor.
108
+ ///
109
+ /// # Panics
110
+ ///
111
+ /// This method will panic if it can't cast the data to `T` due to the type size or the
112
+ /// underlying pointer's alignment.
113
+ pub fn get_data < T > ( & self ) -> Result < & [ T ] > {
114
+ let raw_data = self . get_raw_data ( ) ?;
115
+ let len = get_safe_len :: < T > ( raw_data) ;
116
+ let slice = unsafe { std:: slice:: from_raw_parts ( raw_data. as_ptr ( ) . cast :: < T > ( ) , len) } ;
117
+ Ok ( slice)
118
+ }
119
+
120
+ /// Get a mutable `T`-casted slice of the underlying data for the tensor.
121
+ ///
122
+ /// # Panics
123
+ ///
124
+ /// This method will panic if it can't cast the data to `T` due to the type size or the
125
+ /// underlying pointer's alignment.
126
+ pub fn get_data_mut < T > ( & mut self ) -> Result < & mut [ T ] > {
127
+ let raw_data = self . get_raw_data_mut ( ) ?;
128
+ let len = get_safe_len :: < T > ( raw_data) ;
129
+ let slice =
130
+ unsafe { std:: slice:: from_raw_parts_mut ( raw_data. as_mut_ptr ( ) . cast :: < T > ( ) , len) } ;
131
+ Ok ( slice)
132
+ }
133
+ }
134
+
135
+ /// Convenience function for checking that we can cast `data` to a slice of `T`, returning the
136
+ /// length of that slice.
137
+ fn get_safe_len < T > ( data : & [ u8 ] ) -> usize {
138
+ if data. len ( ) % std:: mem:: size_of :: < T > ( ) != 0 {
139
+ panic ! ( "data size is not a multiple of the size of `T`" ) ;
140
+ }
141
+ if data. as_ptr ( ) as usize % std:: mem:: align_of :: < T > ( ) != 0 {
142
+ panic ! ( "raw data is not aligned to `T`'s alignment" ) ;
143
+ }
144
+ data. len ( ) / std:: mem:: size_of :: < T > ( )
110
145
}
111
146
112
147
#[ cfg( test) ]
113
148
mod tests {
114
149
use super :: * ;
115
- use crate :: { ElementType , LoadingError , Shape } ;
116
150
117
151
#[ test]
118
152
fn test_create_tensor ( ) {
119
- openvino_sys:: library:: load ( )
120
- . map_err ( LoadingError :: SystemFailure )
121
- . unwrap ( ) ;
153
+ openvino_sys:: library:: load ( ) . unwrap ( ) ;
122
154
let shape = Shape :: new ( & vec ! [ 1 , 3 , 227 , 227 ] ) . unwrap ( ) ;
123
155
let tensor = Tensor :: new ( ElementType :: F32 , & shape) . unwrap ( ) ;
124
156
assert ! ( !tensor. ptr. is_null( ) ) ;
125
157
}
126
158
127
159
#[ test]
128
160
fn test_get_shape ( ) {
129
- openvino_sys:: library:: load ( )
130
- . map_err ( LoadingError :: SystemFailure )
131
- . unwrap ( ) ;
161
+ openvino_sys:: library:: load ( ) . unwrap ( ) ;
132
162
let tensor = Tensor :: new (
133
163
ElementType :: F32 ,
134
164
& Shape :: new ( & vec ! [ 1 , 3 , 227 , 227 ] ) . unwrap ( ) ,
@@ -140,9 +170,7 @@ mod tests {
140
170
141
171
#[ test]
142
172
fn test_get_element_type ( ) {
143
- openvino_sys:: library:: load ( )
144
- . map_err ( LoadingError :: SystemFailure )
145
- . unwrap ( ) ;
173
+ openvino_sys:: library:: load ( ) . unwrap ( ) ;
146
174
let tensor = Tensor :: new (
147
175
ElementType :: F32 ,
148
176
& Shape :: new ( & vec ! [ 1 , 3 , 227 , 227 ] ) . unwrap ( ) ,
@@ -154,9 +182,7 @@ mod tests {
154
182
155
183
#[ test]
156
184
fn test_get_size ( ) {
157
- openvino_sys:: library:: load ( )
158
- . map_err ( LoadingError :: SystemFailure )
159
- . unwrap ( ) ;
185
+ openvino_sys:: library:: load ( ) . unwrap ( ) ;
160
186
let tensor = Tensor :: new (
161
187
ElementType :: F32 ,
162
188
& Shape :: new ( & vec ! [ 1 , 3 , 227 , 227 ] ) . unwrap ( ) ,
@@ -168,9 +194,7 @@ mod tests {
168
194
169
195
#[ test]
170
196
fn test_get_byte_size ( ) {
171
- openvino_sys:: library:: load ( )
172
- . map_err ( LoadingError :: SystemFailure )
173
- . unwrap ( ) ;
197
+ openvino_sys:: library:: load ( ) . unwrap ( ) ;
174
198
let tensor = Tensor :: new (
175
199
ElementType :: F32 ,
176
200
& Shape :: new ( & vec ! [ 1 , 3 , 227 , 227 ] ) . unwrap ( ) ,
@@ -182,4 +206,24 @@ mod tests {
182
206
1 * 3 * 227 * 227 * std:: mem:: size_of:: <f32 >( ) as usize
183
207
) ;
184
208
}
209
+
210
+ #[ test]
211
+ fn casting ( ) {
212
+ openvino_sys:: library:: load ( ) . unwrap ( ) ;
213
+ let shape = Shape :: new ( & vec ! [ 10 , 10 , 10 ] ) . unwrap ( ) ;
214
+ let tensor = Tensor :: new ( ElementType :: F32 , & shape) . unwrap ( ) ;
215
+ let data = tensor. get_data :: < f32 > ( ) . unwrap ( ) ;
216
+ assert_eq ! ( data. len( ) , 10 * 10 * 10 ) ;
217
+ }
218
+
219
+ #[ test]
220
+ #[ should_panic( expected = "data size is not a multiple of the size of `T`" ) ]
221
+ fn casting_check ( ) {
222
+ openvino_sys:: library:: load ( ) . unwrap ( ) ;
223
+ let shape = Shape :: new ( & vec ! [ 10 , 10 , 10 ] ) . unwrap ( ) ;
224
+ let tensor = Tensor :: new ( ElementType :: F32 , & shape) . unwrap ( ) ;
225
+ #[ allow( dead_code) ]
226
+ struct LargeOddType ( [ u8 ; 1061 ] ) ;
227
+ tensor. get_data :: < LargeOddType > ( ) . unwrap ( ) ;
228
+ }
185
229
}
0 commit comments