@@ -34,10 +34,12 @@ def _make_efficient_attention_nodes(
34
34
expand_bias : bool ,
35
35
scale : float ,
36
36
dropout_ratio : float ,
37
+ causal : bool ,
37
38
):
38
39
nodes_to_add = []
39
40
scale_node = make_constant_node ("scale_" + str (idx ), TensorProto .FLOAT , [], [scale ])
40
41
dropout_ratio_node = make_constant_node ("dropout_ratio_" + str (idx ), TensorProto .FLOAT , [], [dropout_ratio ])
42
+ causal_node = make_constant_node ("causal_" + str (idx ), TensorProto .INT64 , [], [1 if causal else 0 ])
41
43
int_zero_node = make_constant_node ("int_zero_" + str (idx ), TensorProto .INT64 , [], [0 ])
42
44
true_node = make_constant_node ("true_" + str (idx ), TensorProto .BOOL , [], [True ])
43
45
false_node = make_constant_node ("false_" + str (idx ), TensorProto .BOOL , [], [False ])
@@ -70,7 +72,7 @@ def _make_efficient_attention_nodes(
70
72
"" ,
71
73
"" ,
72
74
dropout_ratio_node .output [0 ],
73
- int_zero_node .output [0 ],
75
+ causal_node .output [0 ],
74
76
true_node .output [0 ],
75
77
scale_node .output [0 ],
76
78
"" ,
@@ -99,7 +101,7 @@ def _make_efficient_attention_nodes(
99
101
dropout_ratio_node .output [0 ],
100
102
seed .name ,
101
103
offset .name ,
102
- int_zero_node .output [0 ],
104
+ causal_node .output [0 ],
103
105
false_node .output [0 ],
104
106
scale_node .output [0 ],
105
107
"" ,
@@ -110,7 +112,9 @@ def _make_efficient_attention_nodes(
110
112
"org.pytorch.aten" ,
111
113
operator = "_efficient_attention_backward" ,
112
114
)
113
- nodes_to_add .extend ([scale_node , dropout_ratio_node , int_zero_node , true_node , false_node , fwd_node , bwd_node ])
115
+ nodes_to_add .extend (
116
+ [scale_node , dropout_ratio_node , causal_node , int_zero_node , true_node , false_node , fwd_node , bwd_node ]
117
+ )
114
118
return nodes_to_add , new_value_infos
115
119
116
120
@@ -172,6 +176,7 @@ def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodePro
172
176
add_input_shape_0 != add_input_shape_1 ,
173
177
1 / float (scale_value [0 ] if isinstance (scale_value , list ) else scale_value ),
174
178
ratio_value ,
179
+ False ,
175
180
)
176
181
return nodes , nodes_to_add , new_value_infos
177
182
@@ -230,13 +235,164 @@ def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: List[NodePro
230
235
add_input_shape_0 != add_input_shape_1 ,
231
236
1 / float (scale_value [0 ] if isinstance (scale_value , list ) else scale_value ),
232
237
0.0 ,
238
+ False ,
239
+ )
240
+ return nodes , nodes_to_add , new_value_infos
241
+
242
+
243
+ # No causal mask, no attention mask, without Dropout.
244
+ _PATTERN_2 : List [Tuple [str , bool , List [Tuple [int , int , int ]]]] = [
245
+ ("MatMul" , False , []), # 0
246
+ ("Mul" , True , [(0 , 0 , 0 )]), # 1
247
+ ("Mul" , True , [(0 , 0 , 1 )]), # 2
248
+ ("Cast" , True , [(1 , 0 , 0 )]), # 3
249
+ ("Cast" , True , [(2 , 0 , 0 )]), # 4
250
+ ("Transpose" , True , [(3 , 0 , 0 )]), # 5
251
+ ("Transpose" , True , [(4 , 0 , 0 )]), # 6
252
+ ("Softmax" , False , [(0 , 0 , 0 )]), # 7
253
+ ("Cast" , False , [(7 , 0 , 0 )]), # 8
254
+ ("MatMul" , False , [(8 , 0 , 0 )]), # 9
255
+ ("Transpose" , True , [(9 , 0 , 1 )]), # 10
256
+ ("Transpose" , False , [(9 , 0 , 0 )]), # 11
257
+ ("FusedMatMul" , False , [(10 , 0 , 1 )]), # 12
258
+ ("Cast" , False , [(12 , 0 , 0 )]), # 13
259
+ ("SoftmaxGrad_13" , False , [(13 , 0 , 0 ), (7 , 0 , 1 )]), # 14
260
+ ("FusedMatMul" , False , [(2 , 0 , 1 ), (14 , 0 , 0 )]), # 15
261
+ ("FusedMatMul" , False , [(1 , 0 , 0 ), (14 , 0 , 1 )]), # 16
262
+ ("Mul" , False , [(15 , 0 , 0 )]), # 17
263
+ ("Mul" , False , [(16 , 0 , 0 )]), # 18
264
+ ("Identity" , False , [(17 , 0 , 0 )]), # 19
265
+ ("Identity" , False , [(18 , 0 , 0 )]), # 20
266
+ ("Cast" , False , [(19 , 0 , 0 )]), # 21
267
+ ("Cast" , False , [(20 , 0 , 0 )]), # 22
268
+ ("Transpose" , False , [(21 , 0 , 0 )]), # 23
269
+ ("Transpose" , False , [(22 , 0 , 0 )]), # 24
270
+ ("FusedMatMul" , False , [(8 , 0 , 0 )]), # 25
271
+ ("Transpose" , True , [(25 , 0 , 1 )]), # 26
272
+ ("Transpose" , False , [(25 , 0 , 0 )]), # 27
273
+ ]
274
+
275
+
276
def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
    """Replace a matched _PATTERN_2 subgraph (no causal/attention mask, no Dropout)
    with efficient-attention nodes.

    Only the forward part is checked; the backward part is expected to be
    consistent if the graph was built correctly. Returns a triple of
    (nodes_to_remove, nodes_to_add, new_value_infos); three empty lists mean
    the match was rejected.
    """
    # Both Mul inputs must carry the same constant scale; unwrap one-element lists.
    q_scale = matcher.get_constant_value(nodes[1].input[1])
    if isinstance(q_scale, list):
        q_scale = q_scale[0]
    k_scale = matcher.get_constant_value(nodes[2].input[1])
    if isinstance(k_scale, list):
        k_scale = k_scale[0]

    # Guard clauses keep the original short-circuit evaluation order.
    if not check_attribute_value(nodes[3], "to", 1):
        return [], [], []
    if not check_attribute_value(nodes[4], "to", 1):
        return [], [], []
    if not check_attribute_value(nodes[5], "perm", [0, 2, 1, 3]):
        return [], [], []
    if not check_attribute_value(nodes[6], "perm", [0, 2, 3, 1]):
        return [], [], []
    if not check_attribute_value(nodes[8], "to", 10):
        return [], [], []
    if not check_attribute_value(nodes[10], "perm", [0, 2, 1, 3]):
        return [], [], []
    if not check_attribute_value(nodes[11], "perm", [0, 2, 1, 3]):
        return [], [], []
    if q_scale != k_scale:
        return [], [], []

    nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
        idx,
        nodes[5].input[0],  # input of first Transpose
        nodes[6].input[0],  # input of second Transpose
        nodes[10].input[0],
        nodes[11].output[0],
        nodes[26].input[0],
        nodes[23].output[0],
        nodes[24].output[0],
        nodes[27].output[0],
        "",  # no bias
        False,  # expand_bias
        q_scale,  # scale
        0.0,  # dropout_ratio
        False,  # causal: pattern 2 has no causal mask
    )
    return nodes, nodes_to_add, new_value_infos
311
+
312
+
313
+ # Has causal mask, no attention mask, without Dropout.
314
+ _PATTERN_3 : List [Tuple [str , bool , List [Tuple [int , int , int ]]]] = [
315
+ ("MatMul" , False , []), # 0
316
+ ("Mul" , True , [(0 , 0 , 0 )]), # 1
317
+ ("Mul" , True , [(0 , 0 , 1 )]), # 2
318
+ ("Cast" , True , [(1 , 0 , 0 )]), # 3
319
+ ("Cast" , True , [(2 , 0 , 0 )]), # 4
320
+ ("Transpose" , True , [(3 , 0 , 0 )]), # 5
321
+ ("Transpose" , True , [(4 , 0 , 0 )]), # 6
322
+ ("Add" , False , [(0 , 0 , 0 )]), # 7
323
+ ("Cast" , True , [(7 , 0 , 1 )]), # 8
324
+ ("Slice" , True , [(8 , 0 , 0 )]), # 9
325
+ ("Slice" , True , [(9 , 0 , 0 )]), # 10
326
+ ("Unsqueeze" , True , [(9 , 0 , 2 )]), # 11
327
+ ("Gather" , True , [(11 , 0 , 0 )]), # 12
328
+ ("Shape" , True , [(12 , 0 , 0 )]), # 13
329
+ ("Softmax" , False , [(7 , 0 , 0 )]), # 14
330
+ ("Cast" , False , [(14 , 0 , 0 )]), # 15
331
+ ("MatMul" , False , [(15 , 0 , 0 )]), # 16
332
+ ("Transpose" , True , [(16 , 0 , 1 )]), # 17
333
+ ("Transpose" , False , [(16 , 0 , 0 )]), # 18
334
+ ("FusedMatMul" , False , [(17 , 0 , 1 )]), # 19
335
+ ("Cast" , False , [(19 , 0 , 0 )]), # 20
336
+ ("SoftmaxGrad_13" , False , [(20 , 0 , 0 ), (14 , 0 , 1 )]), # 21
337
+ ("Identity" , False , [(21 , 0 , 0 )]), # 22
338
+ ("FusedMatMul" , False , [(2 , 0 , 1 ), (22 , 0 , 0 )]), # 23
339
+ ("FusedMatMul" , False , [(1 , 0 , 0 ), (22 , 0 , 1 )]), # 24
340
+ ("Mul" , False , [(23 , 0 , 0 )]), # 25
341
+ ("Mul" , False , [(24 , 0 , 0 )]), # 26
342
+ ("Identity" , False , [(25 , 0 , 0 )]), # 27
343
+ ("Identity" , False , [(26 , 0 , 0 )]), # 28
344
+ ("Cast" , False , [(27 , 0 , 0 )]), # 29
345
+ ("Cast" , False , [(28 , 0 , 0 )]), # 30
346
+ ("Transpose" , False , [(29 , 0 , 0 )]), # 31
347
+ ("Transpose" , False , [(30 , 0 , 0 )]), # 32
348
+ ("FusedMatMul" , False , [(15 , 0 , 0 )]), # 33
349
+ ("Transpose" , True , [(33 , 0 , 1 )]), # 34
350
+ ("Transpose" , False , [(33 , 0 , 0 )]), # 35
351
+ ]
352
+
353
+
354
def _optimize_for_pattern_3(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
    """Replace a matched _PATTERN_3 subgraph (causal mask, no attention mask,
    no Dropout) with efficient-attention nodes.

    Only the forward part is checked; the backward part is expected to be
    consistent if the graph was built correctly. Returns a triple of
    (nodes_to_remove, nodes_to_add, new_value_infos); three empty lists mean
    the match was rejected.
    """
    # Both Mul inputs must carry the same constant scale; unwrap one-element lists.
    q_scale = matcher.get_constant_value(nodes[1].input[1])
    if isinstance(q_scale, list):
        q_scale = q_scale[0]
    k_scale = matcher.get_constant_value(nodes[2].input[1])
    if isinstance(k_scale, list):
        k_scale = k_scale[0]

    # Guard clauses keep the original short-circuit evaluation order.
    if not check_attribute_value(nodes[3], "to", 1):
        return [], [], []
    if not check_attribute_value(nodes[4], "to", 1):
        return [], [], []
    if not check_attribute_value(nodes[5], "perm", [0, 2, 1, 3]):
        return [], [], []
    if not check_attribute_value(nodes[6], "perm", [0, 2, 3, 1]):
        return [], [], []
    if not check_attribute_value(nodes[15], "to", 10):
        return [], [], []
    if not check_attribute_value(nodes[17], "perm", [0, 2, 1, 3]):
        return [], [], []
    if not check_attribute_value(nodes[18], "perm", [0, 2, 1, 3]):
        return [], [], []
    if q_scale != k_scale:
        return [], [], []

    nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
        idx,
        nodes[5].input[0],  # input of first Transpose
        nodes[6].input[0],  # input of second Transpose
        nodes[17].input[0],
        nodes[18].output[0],
        nodes[34].input[0],
        nodes[31].output[0],
        nodes[32].output[0],
        nodes[35].output[0],
        "",  # no bias
        False,  # expand_bias
        q_scale,  # scale
        0.0,  # dropout_ratio
        True,  # causal: pattern 3 carries a causal mask
    )
    return nodes, nodes_to_add, new_value_infos
235
389
236
390
237
391
# Registry of (pattern, rewriter) pairs; patterns are tried in this order.
_PATTERNS = [
    (_PATTERN_0, _optimize_for_pattern_0),
    (_PATTERN_1, _optimize_for_pattern_1),
    (_PATTERN_2, _optimize_for_pattern_2),
    (_PATTERN_3, _optimize_for_pattern_3),
]
241
397
242
398
0 commit comments