@@ -43,8 +43,9 @@
 import torch
 
 from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, should_use_fuse_rope
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_sdp_non_causal
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 from ipex_llm.utils.common import invalidInputError
@@ -198,7 +199,6 @@ def qwen2_vision_attention_forward(
                       "unexpected input")
 
     if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
-        import xe_addons
         image_num = len(seq_lens) - 1
         image_size = seq_lens[1] - seq_lens[0]
         guessed_seq_lens = torch.arange(0, (image_num + 1) * image_size, image_size,
@@ -209,7 +209,10 @@ def qwen2_vision_attention_forward(
             v = v.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
             # q, k, v: [image_num, num_heads, image_size, head_dim]
 
-            attn_output = xe_addons.sdp_non_causal(q, k.contiguous(), v.contiguous(), None)
+            attn_output = scaled_dot_product_attention(
+                q, k.contiguous(), v.contiguous(),
+                None, False
+            )
             attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
             attn_output = attn_output.view(seq_length, self.num_heads, self.head_dim)
             # attn_output: [seq_length, num_heads, head_dim]
@@ -226,7 +229,10 @@ def qwen2_vision_attention_forward(
                 tmp_q = q[:, :, start_idx:end_idx, :]
                 tmp_k = k[:, :, start_idx:end_idx, :]
                 tmp_v = v[:, :, start_idx:end_idx, :]
-                attn_output = xe_addons.sdp_non_causal(tmp_q, tmp_k, tmp_v, None)
+                attn_output = scaled_dot_product_attention(
+                    tmp_q, tmp_k, tmp_v,
+                    None, False
+                )
                 attn_output = attn_output.permute(0, 2, 1, 3)
                 # attn_output: [1, seq_length, num_heads, head_dim]
                 attn_outputs.append(attn_output)
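Note on the two vision branches above: both call sites pass `None` for the mask and `False` for the final flag, i.e. full, non-causal attention over each image's tokens. As a point of reference for what that call has to compute, here is a minimal pure-PyTorch sketch of maskless scaled dot-product attention (reference math only; it is an assumption that the shared ipex_llm helper dispatches to a fused kernel rather than this eager form):

import torch

def sdp_non_causal_reference(q, k, v):
    # q, k, v: [batch, num_heads, seq_len, head_dim] -- same layout as the
    # vision branches above after the view/permute reshapes.
    scale = q.size(-1) ** -0.5
    attn = torch.matmul(q, k.transpose(-2, -1)) * scale      # [b, h, s, s]
    attn = torch.softmax(attn.float(), dim=-1).to(q.dtype)   # fp32 softmax for stability
    return torch.matmul(attn, v)                             # [b, h, s, head_dim]

PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention(q, k, v)` computes the same quantity and can be used to sanity-check the fused path.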
@@ -293,42 +299,11 @@ def qwen2_vl_attention_forward(
     key_states, value_states = past_key_value.update(key_states, value_states,
                                                      self.layer_idx, None)
 
-    kv_seq_len = key_states.size(2)
-    if attention_mask is not None:  # no matter the length, we just slice it
-        causal_mask = attention_mask[:, :, :, :kv_seq_len]
-
     attn_weights = None
-    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, causal_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states, causal_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, causal_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, causal_mask)
-    else:
-        if isinstance(past_key_value, DynamicFp8Cache):
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if causal_mask is not None:
-            attn_weights = attn_weights + causal_mask
-
-        # upcast attention to fp32
-        attn_weights = attention_softmax(attn_weights)
-        attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == key_states.size(2)
+    )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, -1)
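The hunk above collapses three hand-written paths (the fused `xe_addons.sdp`/`sdp_causal` kernels, their FP8 KV-cache variants, and an eager matmul/softmax fallback) into one call to the shared helper; the final argument `q_len == key_states.size(2)` appears to signal the causal (prefill) case. For comparison, below is a sketch of the eager fallback the removed code implemented. It is a hypothetical standalone helper for illustration only, not the library's internal implementation, and it omits the FP8 KV-cache dequantization (`restore_fp8_kv_cache`) step:

import math
import torch

def eager_attention_reference(query_states, key_states, value_states,
                              attention_mask, num_key_value_groups):
    # query_states: [bsz, num_heads, q_len, head_dim]
    # key/value_states: [bsz, num_kv_heads, kv_len, head_dim]
    head_dim = query_states.size(-1)
    kv_len = key_states.size(2)

    # repeat k/v heads if n_kv_heads < n_heads
    # (repeat_interleave on dim=1 matches transformers' repeat_kv helper)
    key_states = key_states.repeat_interleave(num_key_value_groups, dim=1)
    value_states = value_states.repeat_interleave(num_key_value_groups, dim=1)

    attn_weights = torch.matmul(query_states,
                                key_states.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        # additive mask, sliced to the current KV length as in the removed code
        attn_weights = attn_weights + attention_mask[:, :, :, :kv_len]

    # upcast the softmax to fp32 and cast back, per the removed
    # "# upcast attention to fp32" comment
    attn_weights = torch.softmax(attn_weights.float(), dim=-1).to(query_states.dtype)
    return torch.matmul(attn_weights, value_states)  # [bsz, num_heads, q_len, head_dim]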