@@ -179,7 +179,8 @@ ov::pass::RoPEFusionGPTNEOX::RoPEFusionGPTNEOX(int rank) {
179
179
// so here we use a WA, only match the path of rotate_hal(x)*sin and check the x*cos path
180
180
// in the callback
181
181
auto x = pattern::any_input (pattern::rank_equals (rank));
182
- auto cos = pattern::any_input (pattern::rank_equals (rank));
182
+ auto x_or_cos1 = pattern::any_input (pattern::rank_equals (rank));
183
+ auto x_or_cos2 = pattern::any_input (pattern::rank_equals (rank));
183
184
auto t_sin = pattern::any_input (pattern::rank_equals (rank));
184
185
185
186
auto varsplit = pattern::wrap_type<v1::VariadicSplit>({x, rank - 1 , {" half_ndims" , " ?" }});
@@ -192,10 +193,7 @@ ov::pass::RoPEFusionGPTNEOX::RoPEFusionGPTNEOX(int rank) {
192
193
auto x1 = NewGenSlice (x, 0 , " half_ndims" , 1 , rank - 1 );
193
194
auto x_rotate_half = pattern::wrap_type<v0::Concat>({x2neg, x1 | varsplit->output (0 )}, {{" axis" , -1 }});
194
195
195
- auto mul_cos1 = pattern::wrap_type<v1::Multiply>({x, cos}, {{" auto_broadcast" , " numpy" }});
196
- auto mul_cos2 = pattern::wrap_type<v1::Multiply>({cos, x}, {{" auto_broadcast" , " numpy" }});
197
- auto mul_cos = mul_cos1 | mul_cos2;
198
-
196
+ auto mul_cos = pattern::wrap_type<v1::Multiply>({x_or_cos1, x_or_cos2}, {{" auto_broadcast" , " numpy" }});
199
197
auto mul_sin = pattern::wrap_type<v1::Multiply>({x_rotate_half, t_sin}, {{" auto_broadcast" , " numpy" }});
200
198
201
199
auto result = pattern::wrap_type<v1::Add>({mul_cos, mul_sin}, {{" auto_broadcast" , " numpy" }});
@@ -204,6 +202,17 @@ ov::pass::RoPEFusionGPTNEOX::RoPEFusionGPTNEOX(int rank) {
204
202
const auto & pattern_map = m.get_pattern_value_map ();
205
203
auto root = m.get_match_root ();
206
204
205
+ // check mul(x, cos) exists
206
+ Output<Node> v_cos;
207
+ if (pattern_map.at (x_or_cos1) == pattern_map.at (x)) {
208
+ v_cos = pattern_map.at (x_or_cos2);
209
+ } else if (pattern_map.at (x_or_cos2) == pattern_map.at (x)) {
210
+ v_cos = pattern_map.at (x_or_cos1);
211
+ } else {
212
+ // not a RoPE
213
+ return false ;
214
+ }
215
+
207
216
auto symbols = m.get_symbols ();
208
217
auto half_ndims = symbols[" half_ndims" ];
209
218
if (!half_ndims.is_integer ()) {
@@ -218,7 +227,7 @@ ov::pass::RoPEFusionGPTNEOX::RoPEFusionGPTNEOX(int rank) {
218
227
config.rotary_ndims = 2ul * static_cast <size_t >(half_ndims.i ());
219
228
220
229
new_args.push_back (pattern_map.at (x));
221
- new_args.push_back (pattern_map. at (cos) );
230
+ new_args.push_back (v_cos );
222
231
new_args.push_back (pattern_map.at (t_sin));
223
232
auto old_node = root;
224
233
auto new_node = std::make_shared<internal::RoPE>(new_args, config);
0 commit comments