@@ -308,7 +308,7 @@ static std::shared_ptr<ov::Model> buildROPE_GPTNEOX(const int batch,
308
308
static std::shared_ptr<ov::Model> buildROPE_VIT (const int seq_length,
309
309
const int num_heads,
310
310
const int rotary_ndims,
311
- bool is_split ) {
311
+ std::string split_op_type ) {
312
312
auto seq_length_s = static_cast <size_t >(seq_length);
313
313
auto rotary_ndims_s = static_cast <size_t >(rotary_ndims);
314
314
auto num_heads_s = static_cast <size_t >(num_heads);
@@ -321,17 +321,39 @@ static std::shared_ptr<ov::Model> buildROPE_VIT(const int seq_length,
321
321
auto param_sin =
322
322
std::make_shared<ov::opset1::Parameter>(ov::element::f32 , ov::Shape{seq_length_s, 1 , rotary_ndims_s});
323
323
ov::Output<ov::Node> cat_Concat;
324
- if (is_split ) {
324
+ if (split_op_type == " VariadicSplit " ) {
325
325
auto split = makeOP<ov::opset1::VariadicSplit>({input, {2 }, {rotary_ndims / 2 , rotary_ndims / 2 }});
326
326
auto neg_Multiply =
327
327
makeOP<ov::opset1::Multiply>({split->output (1 ), Constant_396096}, {{" auto_broadcast" , " numpy" }});
328
328
cat_Concat = makeOP<ov::opset1::Concat>({neg_Multiply, split->output (0 )}, {{" axis" , -1 }});
329
- } else {
329
+ } else if (split_op_type == " Slice " ) {
330
330
auto slice_right_part = makeOP<ov::opset8::Slice>({input, {rotary_ndims / 2 }, {INT_MAX}, {1 }, {2 }});
331
331
auto slice_left_part = makeOP<ov::opset8::Slice>({input, {0 }, {rotary_ndims / 2 }, {1 }, {2 }});
332
332
auto neg_Multiply =
333
333
makeOP<ov::opset1::Multiply>({slice_right_part, Constant_396096}, {{" auto_broadcast" , " numpy" }});
334
334
cat_Concat = makeOP<ov::opset1::Concat>({neg_Multiply, slice_left_part}, {{" axis" , -1 }});
335
+ } else if (split_op_type == " StridedSlice" ) {
336
+ auto slice_right_part = makeOP<ov::opset1::StridedSlice>({input, {0 , 0 , rotary_ndims / 2 },
337
+ {0 , 0 , INT_MAX},
338
+ {1 , 1 , 1 }},
339
+ {{" begin_mask" , {1 , 1 , 0 }},
340
+ {" end_mask" , {1 , 1 , 0 }},
341
+ {" new_axis_mask" , {}},
342
+ {" shrink_axis_mask" , {}},
343
+ {" ellipsis_mask" , {}}});
344
+ auto slice_left_part = makeOP<ov::opset1::StridedSlice>({input, {0 , 0 , 0 },
345
+ {0 , 0 , rotary_ndims / 2 },
346
+ {1 , 1 , 1 }},
347
+ {{" begin_mask" , {1 , 1 , 0 }},
348
+ {" end_mask" , {1 , 1 , 0 }},
349
+ {" new_axis_mask" , {}},
350
+ {" shrink_axis_mask" , {}},
351
+ {" ellipsis_mask" , {}}});
352
+ auto neg_Multiply =
353
+ makeOP<ov::opset1::Multiply>({slice_right_part, Constant_396096}, {{" auto_broadcast" , " numpy" }});
354
+ cat_Concat = makeOP<ov::opset1::Concat>({neg_Multiply, slice_left_part}, {{" axis" , -1 }});
355
+ } else {
356
+ return nullptr ;
335
357
}
336
358
auto mul_sin_Multiply = makeOP<ov::opset1::Multiply>({cat_Concat, param_sin}, {{" auto_broadcast" , " numpy" }});
337
359
auto mul_cos_Multiply = makeOP<ov::opset1::Multiply>({input, param_cos}, {{" auto_broadcast" , " numpy" }});
@@ -744,14 +766,22 @@ TEST_P(ConvertToROPETest, ConvertToROPE_chatGLM_Slice) {
744
766
745
767
INSTANTIATE_TEST_SUITE_P (TransformationTestsF, ConvertToROPETest, ::testing::ValuesIn({0 , 1 }));
746
768
747
- class ConvertToROPETestVIT : public TransformationTestsF , public ::testing::WithParamInterface<bool > {};
769
+ class ConvertToROPETestVIT : public TransformationTestsF , public ::testing::WithParamInterface<std::string> {
770
+ public:
771
+ static std::string getTestCaseName (const testing::TestParamInfo<std::string>& obj) {
772
+ const auto & split_op_type = obj.param ;
773
+ std::ostringstream result;
774
+ result << " split_op_type=" << split_op_type;
775
+ return result.str ();
776
+ }
777
+ };
748
778
TEST_P (ConvertToROPETestVIT, ConvertToROPE_qwen) {
749
779
disable_rt_info_check ();
750
780
const int seq_len = 16 ;
751
781
const int num_heads = 32 ;
752
782
const int rotary_ndims = 80 ;
753
- const int is_split = GetParam ();
754
- model = buildROPE_VIT (seq_len, num_heads, rotary_ndims, is_split );
783
+ const std::string split_op_type = GetParam ();
784
+ model = buildROPE_VIT (seq_len, num_heads, rotary_ndims, split_op_type );
755
785
manager.register_pass <ov::pass::RoPEFusionVIT3D>();
756
786
{
757
787
auto input =
@@ -778,7 +808,11 @@ TEST_P(ConvertToROPETestVIT, ConvertToROPE_qwen) {
778
808
}
779
809
}
780
810
781
- INSTANTIATE_TEST_SUITE_P (TransformationTestsF, ConvertToROPETestVIT, ::testing::ValuesIn({false , true }));
811
+ const std::vector<std::string> vit_param = {" VariadicSplit" , " Slice" , " StridedSlice" };
812
+ INSTANTIATE_TEST_SUITE_P (TransformationTestsF,
813
+ ConvertToROPETestVIT,
814
+ ::testing::ValuesIn (vit_param),
815
+ ConvertToROPETestVIT::getTestCaseName);
782
816
783
817
TEST_F (TransformationTestsF, ConvertToROPE_GPTJ_Slice) {
784
818
disable_rt_info_check ();
0 commit comments