Skip to content

Commit 22cdecd

Browse files
add more test for tansformation
Signed-off-by: HU Yuan2 <[email protected]>
1 parent 324cb6a commit 22cdecd

File tree

2 files changed

+78
-8
lines changed

2 files changed

+78
-8
lines changed

src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ static std::shared_ptr<ov::Model> buildROPE_GPTNEOX(const int batch,
308308
static std::shared_ptr<ov::Model> buildROPE_VIT(const int seq_length,
309309
const int num_heads,
310310
const int rotary_ndims,
311-
bool is_split) {
311+
std::string split_op_type) {
312312
auto seq_length_s = static_cast<size_t>(seq_length);
313313
auto rotary_ndims_s = static_cast<size_t>(rotary_ndims);
314314
auto num_heads_s = static_cast<size_t>(num_heads);
@@ -321,17 +321,39 @@ static std::shared_ptr<ov::Model> buildROPE_VIT(const int seq_length,
321321
auto param_sin =
322322
std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{seq_length_s, 1, rotary_ndims_s});
323323
ov::Output<ov::Node> cat_Concat;
324-
if (is_split) {
324+
if (split_op_type == "VariadicSplit") {
325325
auto split = makeOP<ov::opset1::VariadicSplit>({input, {2}, {rotary_ndims / 2, rotary_ndims / 2}});
326326
auto neg_Multiply =
327327
makeOP<ov::opset1::Multiply>({split->output(1), Constant_396096}, {{"auto_broadcast", "numpy"}});
328328
cat_Concat = makeOP<ov::opset1::Concat>({neg_Multiply, split->output(0)}, {{"axis", -1}});
329-
} else {
329+
} else if (split_op_type == "Slice") {
330330
auto slice_right_part = makeOP<ov::opset8::Slice>({input, {rotary_ndims / 2}, {INT_MAX}, {1}, {2}});
331331
auto slice_left_part = makeOP<ov::opset8::Slice>({input, {0}, {rotary_ndims / 2}, {1}, {2}});
332332
auto neg_Multiply =
333333
makeOP<ov::opset1::Multiply>({slice_right_part, Constant_396096}, {{"auto_broadcast", "numpy"}});
334334
cat_Concat = makeOP<ov::opset1::Concat>({neg_Multiply, slice_left_part}, {{"axis", -1}});
335+
} else if (split_op_type == "StridedSlice") {
336+
auto slice_right_part = makeOP<ov::opset1::StridedSlice>({input, {0, 0, rotary_ndims / 2},
337+
{0, 0, INT_MAX},
338+
{1, 1, 1}},
339+
{{"begin_mask", {1, 1, 0}},
340+
{"end_mask", {1, 1, 0}},
341+
{"new_axis_mask", {}},
342+
{"shrink_axis_mask", {}},
343+
{"ellipsis_mask", {}}});
344+
auto slice_left_part = makeOP<ov::opset1::StridedSlice>({input, {0, 0, 0},
345+
{0, 0, rotary_ndims / 2},
346+
{1, 1, 1}},
347+
{{"begin_mask", {1, 1, 0}},
348+
{"end_mask", {1, 1, 0}},
349+
{"new_axis_mask", {}},
350+
{"shrink_axis_mask", {}},
351+
{"ellipsis_mask", {}}});
352+
auto neg_Multiply =
353+
makeOP<ov::opset1::Multiply>({slice_right_part, Constant_396096}, {{"auto_broadcast", "numpy"}});
354+
cat_Concat = makeOP<ov::opset1::Concat>({neg_Multiply, slice_left_part}, {{"axis", -1}});
355+
} else {
356+
return nullptr;
335357
}
336358
auto mul_sin_Multiply = makeOP<ov::opset1::Multiply>({cat_Concat, param_sin}, {{"auto_broadcast", "numpy"}});
337359
auto mul_cos_Multiply = makeOP<ov::opset1::Multiply>({input, param_cos}, {{"auto_broadcast", "numpy"}});
@@ -744,14 +766,22 @@ TEST_P(ConvertToROPETest, ConvertToROPE_chatGLM_Slice) {
744766

745767
INSTANTIATE_TEST_SUITE_P(TransformationTestsF, ConvertToROPETest, ::testing::ValuesIn({0, 1}));
746768

747-
class ConvertToROPETestVIT : public TransformationTestsF, public ::testing::WithParamInterface<bool> {};
769+
class ConvertToROPETestVIT : public TransformationTestsF, public ::testing::WithParamInterface<std::string> {
770+
public:
771+
static std::string getTestCaseName(const testing::TestParamInfo<std::string>& obj) {
772+
const auto& split_op_type = obj.param;
773+
std::ostringstream result;
774+
result << "split_op_type=" << split_op_type;
775+
return result.str();
776+
}
777+
};
748778
TEST_P(ConvertToROPETestVIT, ConvertToROPE_qwen) {
749779
disable_rt_info_check();
750780
const int seq_len = 16;
751781
const int num_heads = 32;
752782
const int rotary_ndims = 80;
753-
const int is_split = GetParam();
754-
model = buildROPE_VIT(seq_len, num_heads, rotary_ndims, is_split);
783+
const std::string split_op_type = GetParam();
784+
model = buildROPE_VIT(seq_len, num_heads, rotary_ndims, split_op_type);
755785
manager.register_pass<ov::pass::RoPEFusionVIT3D>();
756786
{
757787
auto input =
@@ -778,7 +808,11 @@ TEST_P(ConvertToROPETestVIT, ConvertToROPE_qwen) {
778808
}
779809
}
780810

781-
INSTANTIATE_TEST_SUITE_P(TransformationTestsF, ConvertToROPETestVIT, ::testing::ValuesIn({false, true}));
811+
const std::vector<std::string> vit_param = {"VariadicSplit", "Slice", "StridedSlice"};
812+
INSTANTIATE_TEST_SUITE_P(TransformationTestsF,
813+
ConvertToROPETestVIT,
814+
::testing::ValuesIn(vit_param),
815+
ConvertToROPETestVIT::getTestCaseName);
782816

783817
TEST_F(TransformationTestsF, ConvertToROPE_GPTJ_Slice) {
784818
disable_rt_info_check();

src/tests/functional/plugin/shared/src/subgraph/rotary_pos_emb.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,47 @@ std::string RoPETestFlux::getTestCaseName(const testing::TestParamInfo<rope_para
117117
std::shared_ptr<ov::Model> RoPETestQwenVL::buildROPE_QwenVL(ov::element::Type element_type,
118118
ov::PartialShape input_shape,
119119
ov::PartialShape cos_shape,
120-
ov::PartialShape sin_shape) {
120+
ov::PartialShape sin_shape,
121+
std::string split_op_type) {
121122
auto input = std::make_shared<opset1::Parameter>(element_type, input_shape);
122123
auto cos_cache = std::make_shared<opset1::Parameter>(element_type, cos_shape);
123124
auto cos_mul = makeOP<opset1::Multiply>({input, cos_cache}, {{"auto_broadcast", "numpy"}});
124125
auto sin_cache = std::make_shared<opset1::Parameter>(element_type, sin_shape);
126+
if (split_op_type == "VariadicSplit") {
127+
auto split = makeOP<ov::opset1::VariadicSplit>({input, {2}, {rotary_ndims / 2, rotary_ndims / 2}});
128+
auto neg_Multiply =
129+
makeOP<ov::opset1::Multiply>({split->output(1), Constant_396096}, {{"auto_broadcast", "numpy"}});
130+
cat_Concat = makeOP<ov::opset1::Concat>({neg_Multiply, split->output(0)}, {{"axis", -1}});
131+
} else if (split_op_type == "Slice") {
132+
auto slice_right_part = makeOP<ov::opset8::Slice>({input, {rotary_ndims / 2}, {INT_MAX}, {1}, {2}});
133+
auto slice_left_part = makeOP<ov::opset8::Slice>({input, {0}, {rotary_ndims / 2}, {1}, {2}});
134+
auto neg_Multiply =
135+
makeOP<ov::opset1::Multiply>({slice_right_part, Constant_396096}, {{"auto_broadcast", "numpy"}});
136+
cat_Concat = makeOP<ov::opset1::Concat>({neg_Multiply, slice_left_part}, {{"axis", -1}});
137+
} else if (split_op_type == "StridedSlice") {
138+
auto slice_right_part = makeOP<ov::opset1::StridedSlice>({input, {0, 0, rotary_ndims / 2},
139+
{0, 0, INT_MAX},
140+
{1, 1, 1}},
141+
{{"begin_mask", {1, 1, 0}},
142+
{"end_mask", {1, 1, 0}},
143+
{"new_axis_mask", {}},
144+
{"shrink_axis_mask", {}},
145+
{"ellipsis_mask", {}}});
146+
auto slice_left_part = makeOP<ov::opset1::StridedSlice>({input, {0, 0, 0},
147+
{0, 0, rotary_ndims / 2},
148+
{1, 1, 1}},
149+
{{"begin_mask", {1, 1, 0}},
150+
{"end_mask", {1, 1, 0}},
151+
{"new_axis_mask", {}},
152+
{"shrink_axis_mask", {}},
153+
{"ellipsis_mask", {}}});
154+
auto neg_Multiply =
155+
makeOP<ov::opset1::Multiply>({slice_right_part, Constant_396096}, {{"auto_broadcast", "numpy"}});
156+
cat_Concat = makeOP<ov::opset1::Concat>({neg_Multiply, slice_left_part}, {{"axis", -1}});
157+
} else {
158+
return nullptr;
159+
}
160+
125161
auto input_slice = makeOP<opset1::StridedSlice>({input, {0, 0, 40}, {0, 0, INT_MAX}, {1, 1, 1}},
126162
{{"begin_mask", {1, 1, 0}},
127163
{"end_mask", {1, 1, 0}},

0 commit comments

Comments
 (0)