@@ -76,7 +76,7 @@ def mha_ref(q, k, v, scale, is_causal, window_size, softcap):
 
 class TestFlashAttnVarLen(TestCase):
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def _test_flash_attn_varlen(
         self,
         num_heads: int,
@@ -86,6 +86,7 @@ def _test_flash_attn_varlen(
         is_causal: bool,
         dtype: torch.dtype,
         softcap: float,
+        is_compile: bool,
     ) -> None:
         random.seed(0)
         torch.manual_seed(0)
@@ -163,7 +164,11 @@ def _test_flash_attn_varlen(
             output_ref[cu_seq_lens_q[i] : cu_seq_lens_q[i + 1]] = output_i
 
         output = torch.empty_like(query)
-        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+        if is_compile:
+            f_c = torch.compile(ipex.llm.modules.PagedAttention.flash_attn_varlen_func)
+        else:
+            f_c = ipex.llm.modules.PagedAttention.flash_attn_varlen_func
+        f_c(
             output,
             query,
             k_cache,
@@ -185,7 +190,7 @@ def _test_flash_attn_varlen(
             output_ref, output, atol=1e-6 if dtype == torch.float else 5e-2
         )
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def _test_flash_attn_varlen_fp8(
         self,
         num_heads: int,
@@ -195,6 +200,7 @@ def _test_flash_attn_varlen_fp8(
         is_causal: bool,
         dtype: torch.dtype,
         softcap: float,
+        is_compile: bool,
     ) -> None:
         random.seed(0)
         torch.manual_seed(0)
@@ -255,6 +261,7 @@ def _test_flash_attn_varlen_fp8(
         scale = float(1.0 / (head_size**0.5))
 
         output_ref = torch.empty_like(query)
+
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             output_ref,
             query,
@@ -272,9 +279,12 @@ def _test_flash_attn_varlen_fp8(
             window_size[1],
             softcap=softcap,
         )
-
+        if is_compile:
+            f_c = torch.compile(ipex.llm.modules.PagedAttention.flash_attn_varlen_func)
+        else:
+            f_c = ipex.llm.modules.PagedAttention.flash_attn_varlen_func
         output = torch.empty_like(query)
-        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+        f_c(
             output,
             query,
             k_cache.to(torch.float8_e5m2),
@@ -297,6 +307,9 @@ def _test_flash_attn_varlen_fp8(
         )
 
     def test_flash_attn_varlen(self):
+        COMPILE_TEST = (
+            1  # test torch.compile function for once, avoiding recompile in CI
+        )
         for (
             num_heads,
             num_queries_per_kv,
@@ -314,17 +327,24 @@ def test_flash_attn_varlen(self):
             DTYPES,
             SOFTCAP,
         ):
-            self._test_flash_attn_varlen(
-                num_heads,
-                num_queries_per_kv,
-                head_size,
-                window_size,
-                is_causal,
-                dtype,
-                softcap,
-            )
+            COMPILE = [True, False] if COMPILE_TEST == 1 else [False]
+            COMPILE_TEST = COMPILE_TEST - 1
+            for is_compile in COMPILE:
+                self._test_flash_attn_varlen(
+                    num_heads,
+                    num_queries_per_kv,
+                    head_size,
+                    window_size,
+                    is_causal,
+                    dtype,
+                    softcap,
+                    is_compile,
+                )
 
     def test_flash_attn_varlen_fp8(self):
+        COMPILE_TEST = (
+            1  # test torch.compile function for once, avoiding recompile in CI
+        )
         for (
             num_heads,
             num_queries_per_kv,
@@ -342,15 +362,19 @@ def test_flash_attn_varlen_fp8(self):
             [torch.float, torch.bfloat16],
             SOFTCAP,
         ):
-            self._test_flash_attn_varlen_fp8(
-                num_heads,
-                num_queries_per_kv,
-                head_size,
-                window_size,
-                is_causal,
-                dtype,
-                softcap,
-            )
+            COMPILE = [True, False] if COMPILE_TEST == 1 else [False]
+            COMPILE_TEST = COMPILE_TEST - 1
+            for is_compile in COMPILE:
+                self._test_flash_attn_varlen_fp8(
+                    num_heads,
+                    num_queries_per_kv,
+                    head_size,
+                    window_size,
+                    is_causal,
+                    dtype,
+                    softcap,
+                    is_compile,
+                )
 
 
 if __name__ == "__main__":
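
For reference, a minimal standalone sketch of the compile-toggle pattern exercised by the tests above; toy_attention and its tensor shapes are hypothetical stand-ins rather than the IPEX PagedAttention API, and the sketch assumes only stock PyTorch 2.x:

import torch


def toy_attention(q, k, v, scale):
    # Hypothetical stand-in for the kernel under test (not the IPEX API).
    return torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v


@torch.no_grad()  # matches the tests above: gradient-free, inference-only execution
def run(is_compile: bool):
    q = k = v = torch.randn(1, 4, 8, 16)
    # Wrap the same callable with torch.compile only when requested, so a
    # sweep compiles once and falls back to eager mode for the other cases.
    fn = torch.compile(toy_attention) if is_compile else toy_attention
    return fn(q, k, v, scale=16**-0.5)


for is_compile in (True, False):
    out = run(is_compile)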