@@ -40,17 +40,15 @@
 import torch
 import warnings
 
-from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
-from ipex_llm.transformers.models.utils import (
-    use_quantize_kv_cache, restore_fp8_kv_cache,
-    should_use_fuse_rope, use_sdp, use_sdp_causal
-)
+from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache, should_use_fuse_rope
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 from ipex_llm.utils.common.log4Error import invalidInputError
 
 from typing import Optional, Tuple, List
 from transformers.cache_utils import Cache
-from transformers.models.starcoder2.modeling_starcoder2 import repeat_kv, apply_rotary_pos_emb
+from transformers.models.starcoder2.modeling_starcoder2 import apply_rotary_pos_emb
 from transformers.models.starcoder2.modeling_starcoder2 import Starcoder2Model, Starcoder2Attention
 
 
@@ -103,41 +101,11 @@ def attention_forward(
                                                      self.layer_idx, None)
 
     # IPEX-LLM OPT: sdp
-    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                            attention_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states,
-                                        attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, attention_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, attention_mask)
-    else:
-        if isinstance(past_key_value, DynamicFp8Cache):
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-
-        # upcast attention to fp32
-        attn_weights = attention_softmax(attn_weights)
-        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
-                                                   training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
+    attn_weights = None
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == kv_seq_len
+    )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)