Skip to content

Commit 3da5bce

Browse files
committed
Add new fusion patterns for efficient attention (with and without causal mask)
1 parent 5431e8e commit 3da5bce

File tree

1 file changed

+159
-3
lines changed
  • orttraining/orttraining/python/training/ortmodule/graph_optimizers

1 file changed

+159
-3
lines changed

orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py

Lines changed: 159 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ def _make_efficient_attention_nodes(
3434
expand_bias: bool,
3535
scale: float,
3636
dropout_ratio: float,
37+
causal: bool,
3738
):
3839
nodes_to_add = []
3940
scale_node = make_constant_node("scale_" + str(idx), TensorProto.FLOAT, [], [scale])
4041
dropout_ratio_node = make_constant_node("dropout_ratio_" + str(idx), TensorProto.FLOAT, [], [dropout_ratio])
42+
causal_node = make_constant_node("causal_" + str(idx), TensorProto.INT64, [], [1 if causal else 0])
4143
int_zero_node = make_constant_node("int_zero_" + str(idx), TensorProto.INT64, [], [0])
4244
true_node = make_constant_node("true_" + str(idx), TensorProto.BOOL, [], [True])
4345
false_node = make_constant_node("false_" + str(idx), TensorProto.BOOL, [], [False])
@@ -70,7 +72,7 @@ def _make_efficient_attention_nodes(
7072
"",
7173
"",
7274
dropout_ratio_node.output[0],
73-
int_zero_node.output[0],
75+
causal_node.output[0],
7476
true_node.output[0],
7577
scale_node.output[0],
7678
"",
@@ -99,7 +101,7 @@ def _make_efficient_attention_nodes(
99101
dropout_ratio_node.output[0],
100102
seed.name,
101103
offset.name,
102-
int_zero_node.output[0],
104+
causal_node.output[0],
103105
false_node.output[0],
104106
scale_node.output[0],
105107
"",
@@ -110,7 +112,9 @@ def _make_efficient_attention_nodes(
110112
"org.pytorch.aten",
111113
operator="_efficient_attention_backward",
112114
)
113-
nodes_to_add.extend([scale_node, dropout_ratio_node, int_zero_node, true_node, false_node, fwd_node, bwd_node])
115+
nodes_to_add.extend(
116+
[scale_node, dropout_ratio_node, causal_node, int_zero_node, true_node, false_node, fwd_node, bwd_node]
117+
)
114118
return nodes_to_add, new_value_infos
115119

116120

@@ -172,6 +176,7 @@ def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodePro
172176
add_input_shape_0 != add_input_shape_1,
173177
1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value),
174178
ratio_value,
179+
False,
175180
)
176181
return nodes, nodes_to_add, new_value_infos
177182

@@ -230,13 +235,164 @@ def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: List[NodePro
230235
add_input_shape_0 != add_input_shape_1,
231236
1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value),
232237
0.0,
238+
False,
239+
)
240+
return nodes, nodes_to_add, new_value_infos
241+
242+
243+
# Pattern 2: no causal mask, no attention mask, no Dropout.
# Each entry is (op_type, is_pattern_input_side, edges); each edge (i, j, k)
# presumably means "output j of pattern node i connects to input k of this
# node" — TODO confirm against GraphMatcher's matching convention.
_PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
    ("MatMul", False, []),  # 0
    ("Mul", True, [(0, 0, 0)]),  # 1
    ("Mul", True, [(0, 0, 1)]),  # 2
    ("Cast", True, [(1, 0, 0)]),  # 3
    ("Cast", True, [(2, 0, 0)]),  # 4
    ("Transpose", True, [(3, 0, 0)]),  # 5
    ("Transpose", True, [(4, 0, 0)]),  # 6
    ("Softmax", False, [(0, 0, 0)]),  # 7
    ("Cast", False, [(7, 0, 0)]),  # 8
    ("MatMul", False, [(8, 0, 0)]),  # 9
    ("Transpose", True, [(9, 0, 1)]),  # 10
    ("Transpose", False, [(9, 0, 0)]),  # 11
    ("FusedMatMul", False, [(10, 0, 1)]),  # 12
    ("Cast", False, [(12, 0, 0)]),  # 13
    ("SoftmaxGrad_13", False, [(13, 0, 0), (7, 0, 1)]),  # 14
    ("FusedMatMul", False, [(2, 0, 1), (14, 0, 0)]),  # 15
    ("FusedMatMul", False, [(1, 0, 0), (14, 0, 1)]),  # 16
    ("Mul", False, [(15, 0, 0)]),  # 17
    ("Mul", False, [(16, 0, 0)]),  # 18
    ("Identity", False, [(17, 0, 0)]),  # 19
    ("Identity", False, [(18, 0, 0)]),  # 20
    ("Cast", False, [(19, 0, 0)]),  # 21
    ("Cast", False, [(20, 0, 0)]),  # 22
    ("Transpose", False, [(21, 0, 0)]),  # 23
    ("Transpose", False, [(22, 0, 0)]),  # 24
    ("FusedMatMul", False, [(8, 0, 0)]),  # 25
    ("Transpose", True, [(25, 0, 1)]),  # 26
    ("Transpose", False, [(25, 0, 0)]),  # 27
]
274+
275+
276+
def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
    """Rewrite a matched pattern-2 subgraph (no causal/attention mask, no Dropout)
    into efficient-attention nodes.

    Returns (nodes_to_remove, nodes_to_add, new_value_infos); returns three empty
    lists when the matched nodes fail validation.
    """
    # Validate the forward path only; the backward graph is expected to be
    # consistent with it when the graph was built correctly.
    q_scale = matcher.get_constant_value(nodes[1].input[1])
    if isinstance(q_scale, list):
        q_scale = q_scale[0]
    k_scale = matcher.get_constant_value(nodes[2].input[1])
    if isinstance(k_scale, list):
        k_scale = k_scale[0]

    # All attribute constraints plus equal Q/K scaling must hold for the fusion
    # to be valid. check_attribute_value is a pure read, so eager evaluation of
    # the whole tuple is equivalent to the short-circuit form.
    validations = (
        check_attribute_value(nodes[3], "to", 1),
        check_attribute_value(nodes[4], "to", 1),
        check_attribute_value(nodes[5], "perm", [0, 2, 1, 3]),
        check_attribute_value(nodes[6], "perm", [0, 2, 3, 1]),
        check_attribute_value(nodes[8], "to", 10),
        check_attribute_value(nodes[10], "perm", [0, 2, 1, 3]),
        check_attribute_value(nodes[11], "perm", [0, 2, 1, 3]),
        q_scale == k_scale,
    )
    if not all(validations):
        return [], [], []

    # Leading arguments are the I/O names pulled from the matched subgraph
    # (presumably q/k/v inputs and the fwd/bwd outputs — confirm against
    # _make_efficient_attention_nodes' signature).
    nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
        idx,
        nodes[5].input[0],
        nodes[6].input[0],
        nodes[10].input[0],
        nodes[11].output[0],
        nodes[26].input[0],
        nodes[23].output[0],
        nodes[24].output[0],
        nodes[27].output[0],
        "",
        False,  # expand_bias
        q_scale,  # scale
        0.0,  # dropout_ratio
        False,  # causal
    )
    return nodes, nodes_to_add, new_value_infos
311+
312+
313+
# Pattern 3: causal mask present, no attention mask, no Dropout.
# Each entry is (op_type, is_pattern_input_side, edges); each edge (i, j, k)
# presumably means "output j of pattern node i connects to input k of this
# node" — TODO confirm against GraphMatcher's matching convention.
# Nodes 8-13 (Cast/Slice/Slice/Unsqueeze/Gather/Shape) form the causal-mask
# construction feeding the Add at node 7.
_PATTERN_3: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
    ("MatMul", False, []),  # 0
    ("Mul", True, [(0, 0, 0)]),  # 1
    ("Mul", True, [(0, 0, 1)]),  # 2
    ("Cast", True, [(1, 0, 0)]),  # 3
    ("Cast", True, [(2, 0, 0)]),  # 4
    ("Transpose", True, [(3, 0, 0)]),  # 5
    ("Transpose", True, [(4, 0, 0)]),  # 6
    ("Add", False, [(0, 0, 0)]),  # 7
    ("Cast", True, [(7, 0, 1)]),  # 8
    ("Slice", True, [(8, 0, 0)]),  # 9
    ("Slice", True, [(9, 0, 0)]),  # 10
    ("Unsqueeze", True, [(9, 0, 2)]),  # 11
    ("Gather", True, [(11, 0, 0)]),  # 12
    ("Shape", True, [(12, 0, 0)]),  # 13
    ("Softmax", False, [(7, 0, 0)]),  # 14
    ("Cast", False, [(14, 0, 0)]),  # 15
    ("MatMul", False, [(15, 0, 0)]),  # 16
    ("Transpose", True, [(16, 0, 1)]),  # 17
    ("Transpose", False, [(16, 0, 0)]),  # 18
    ("FusedMatMul", False, [(17, 0, 1)]),  # 19
    ("Cast", False, [(19, 0, 0)]),  # 20
    ("SoftmaxGrad_13", False, [(20, 0, 0), (14, 0, 1)]),  # 21
    ("Identity", False, [(21, 0, 0)]),  # 22
    ("FusedMatMul", False, [(2, 0, 1), (22, 0, 0)]),  # 23
    ("FusedMatMul", False, [(1, 0, 0), (22, 0, 1)]),  # 24
    ("Mul", False, [(23, 0, 0)]),  # 25
    ("Mul", False, [(24, 0, 0)]),  # 26
    ("Identity", False, [(25, 0, 0)]),  # 27
    ("Identity", False, [(26, 0, 0)]),  # 28
    ("Cast", False, [(27, 0, 0)]),  # 29
    ("Cast", False, [(28, 0, 0)]),  # 30
    ("Transpose", False, [(29, 0, 0)]),  # 31
    ("Transpose", False, [(30, 0, 0)]),  # 32
    ("FusedMatMul", False, [(15, 0, 0)]),  # 33
    ("Transpose", True, [(33, 0, 1)]),  # 34
    ("Transpose", False, [(33, 0, 0)]),  # 35
]
352+
353+
354+
def _optimize_for_pattern_3(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
    """Rewrite a matched pattern-3 subgraph (causal mask, no attention mask,
    no Dropout) into efficient-attention nodes with causal=True.

    Returns (nodes_to_remove, nodes_to_add, new_value_infos); returns three empty
    lists when the matched nodes fail validation.
    """
    # Validate the forward path only; the backward graph is expected to be
    # consistent with it when the graph was built correctly.
    q_scale = matcher.get_constant_value(nodes[1].input[1])
    if isinstance(q_scale, list):
        q_scale = q_scale[0]
    k_scale = matcher.get_constant_value(nodes[2].input[1])
    if isinstance(k_scale, list):
        k_scale = k_scale[0]

    # All attribute constraints plus equal Q/K scaling must hold for the fusion
    # to be valid. check_attribute_value is a pure read, so eager evaluation of
    # the whole tuple is equivalent to the short-circuit form.
    validations = (
        check_attribute_value(nodes[3], "to", 1),
        check_attribute_value(nodes[4], "to", 1),
        check_attribute_value(nodes[5], "perm", [0, 2, 1, 3]),
        check_attribute_value(nodes[6], "perm", [0, 2, 3, 1]),
        check_attribute_value(nodes[15], "to", 10),
        check_attribute_value(nodes[17], "perm", [0, 2, 1, 3]),
        check_attribute_value(nodes[18], "perm", [0, 2, 1, 3]),
        q_scale == k_scale,
    )
    if not all(validations):
        return [], [], []

    # Leading arguments are the I/O names pulled from the matched subgraph
    # (presumably q/k/v inputs and the fwd/bwd outputs — confirm against
    # _make_efficient_attention_nodes' signature).
    nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
        idx,
        nodes[5].input[0],
        nodes[6].input[0],
        nodes[17].input[0],
        nodes[18].output[0],
        nodes[34].input[0],
        nodes[31].output[0],
        nodes[32].output[0],
        nodes[35].output[0],
        "",
        False,  # expand_bias
        q_scale,  # scale
        0.0,  # dropout_ratio
        True,  # causal
    )
    return nodes, nodes_to_add, new_value_infos
235389

236390

237391
_PATTERNS = [
238392
(_PATTERN_0, _optimize_for_pattern_0),
239393
(_PATTERN_1, _optimize_for_pattern_1),
394+
(_PATTERN_2, _optimize_for_pattern_2),
395+
(_PATTERN_3, _optimize_for_pattern_3),
240396
]
241397

242398

0 commit comments

Comments
 (0)