Skip to content

Commit ea00c5f

Browse files
Revert to the original mul_cos code
Signed-off-by: HU Yuan2 <[email protected]>
1 parent 14c1db1 commit ea00c5f

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ ov::pass::RoPEFusionGPTNEOX::RoPEFusionGPTNEOX(int rank) {
179179
// so here we use a workaround (WA): only match the path of rotate_half(x)*sin and check the x*cos path
180180
// in the callback
181181
auto x = pattern::any_input(pattern::rank_equals(rank));
182-
auto cos = pattern::any_input(pattern::rank_equals(rank));
182+
auto x_or_cos1 = pattern::any_input(pattern::rank_equals(rank));
183+
auto x_or_cos2 = pattern::any_input(pattern::rank_equals(rank));
183184
auto t_sin = pattern::any_input(pattern::rank_equals(rank));
184185

185186
auto varsplit = pattern::wrap_type<v1::VariadicSplit>({x, rank - 1, {"half_ndims", "?"}});
@@ -192,10 +193,7 @@ ov::pass::RoPEFusionGPTNEOX::RoPEFusionGPTNEOX(int rank) {
192193
auto x1 = NewGenSlice(x, 0, "half_ndims", 1, rank - 1);
193194
auto x_rotate_half = pattern::wrap_type<v0::Concat>({x2neg, x1 | varsplit->output(0)}, {{"axis", -1}});
194195

195-
auto mul_cos1 = pattern::wrap_type<v1::Multiply>({x, cos}, {{"auto_broadcast", "numpy"}});
196-
auto mul_cos2 = pattern::wrap_type<v1::Multiply>({cos, x}, {{"auto_broadcast", "numpy"}});
197-
auto mul_cos = mul_cos1 | mul_cos2;
198-
196+
auto mul_cos = pattern::wrap_type<v1::Multiply>({x_or_cos1, x_or_cos2}, {{"auto_broadcast", "numpy"}});
199197
auto mul_sin = pattern::wrap_type<v1::Multiply>({x_rotate_half, t_sin}, {{"auto_broadcast", "numpy"}});
200198

201199
auto result = pattern::wrap_type<v1::Add>({mul_cos, mul_sin}, {{"auto_broadcast", "numpy"}});
@@ -204,6 +202,17 @@ ov::pass::RoPEFusionGPTNEOX::RoPEFusionGPTNEOX(int rank) {
204202
const auto& pattern_map = m.get_pattern_value_map();
205203
auto root = m.get_match_root();
206204

205+
// check that mul(x, cos) exists: one input of mul_cos must be x, and the other one is used as cos
206+
Output<Node> v_cos;
207+
if (pattern_map.at(x_or_cos1) == pattern_map.at(x)) {
208+
v_cos = pattern_map.at(x_or_cos2);
209+
} else if (pattern_map.at(x_or_cos2) == pattern_map.at(x)) {
210+
v_cos = pattern_map.at(x_or_cos1);
211+
} else {
212+
// not a RoPE
213+
return false;
214+
}
215+
207216
auto symbols = m.get_symbols();
208217
auto half_ndims = symbols["half_ndims"];
209218
if (!half_ndims.is_integer()) {
@@ -218,7 +227,7 @@ ov::pass::RoPEFusionGPTNEOX::RoPEFusionGPTNEOX(int rank) {
218227
config.rotary_ndims = 2ul * static_cast<size_t>(half_ndims.i());
219228

220229
new_args.push_back(pattern_map.at(x));
221-
new_args.push_back(pattern_map.at(cos));
230+
new_args.push_back(v_cos);
222231
new_args.push_back(pattern_map.at(t_sin));
223232
auto old_node = root;
224233
auto new_node = std::make_shared<internal::RoPE>(new_args, config);

0 commit comments

Comments
 (0)