
Commit b050368

refactor yuan2 and starcoder2 and fix (#12589)
1 parent 6ea8033 commit b050368

File tree

6 files changed: 28 additions, 83 deletions


python/llm/src/ipex_llm/transformers/models/llama32.py

Lines changed: 1 addition & 1 deletion

@@ -234,7 +234,7 @@ def llama_attention_forward(
     attn_weights = None
     attn_output = scaled_dot_product_attention(
         query_states, key_states, value_states,
-        attention_mask, q_len == key_states.size(2), math.sqrt(self.head_dim)
+        attention_mask, q_len == key_states.size(2)
     )

     attn_output = attn_output.transpose(1, 2).contiguous()
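Across llama32.py, minicpm.py, qwen2.py, starcoder2.py and yuan.py this commit drops the trailing math.sqrt(self.head_dim) argument from the scaled_dot_product_attention call, which suggests the shared helper in ipex_llm.transformers.models.common now falls back to the standard 1/sqrt(head_dim) scaling when no scale is passed, and only needs an explicit value where that default is wrong (see the minicpmv.py hunk below). The sketch that follows is not the IPEX-LLM implementation; it is a minimal approximation of such a helper built on PyTorch's torch.nn.functional.scaled_dot_product_attention, with the argument order copied from the call sites in this diff and the meaning of the boolean flag inferred from them.

# Hypothetical sketch only: it mimics the call contract visible in this diff,
# not the actual ipex_llm.transformers.models.common implementation.
import math
from typing import Optional

import torch
import torch.nn.functional as F


def sdpa_like(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
              attention_mask: Optional[torch.Tensor], is_causal: bool,
              scale: Optional[float] = None) -> torch.Tensor:
    # When the caller omits the scale, fall back to the conventional
    # 1/sqrt(head_dim), which is also what F.scaled_dot_product_attention
    # uses internally by default.
    if scale is None:
        scale = 1 / math.sqrt(query.size(-1))
    # PyTorch raises an error if both attn_mask and is_causal are supplied,
    # so only one of them is passed through. Grouped-query models would
    # additionally need their KV heads repeated (or enable_gqa=True on
    # PyTorch >= 2.5) before this call.
    return F.scaled_dot_product_attention(
        query, key, value,
        attn_mask=None if is_causal else attention_mask,
        is_causal=is_causal,
        scale=scale,
    )

With that default in place, a call site only has to pass query/key/value, the mask and the causal flag, which is exactly the shape of the + lines in this commit.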

python/llm/src/ipex_llm/transformers/models/minicpm.py

Lines changed: 3 additions & 6 deletions

@@ -38,15 +38,13 @@

 import torch
 import warnings
-import torch.nn as nn
 from typing import Optional, Tuple, Union, List
 import math
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, use_quantize_kv_cache
-from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, get_compresskv_attn_mask
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import should_use_compresskv, should_use_fuse_rope
-from ipex_llm.transformers.models.llama import repeat_kv
 from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, \
     DynamicCompressCache, DynamicCompressFp8Cache
 from transformers.cache_utils import Cache

@@ -127,11 +125,10 @@ def minicpm_attention_forward(
     key_states, value_states = past_key_value.update(key_states, value_states,
                                                      self.layer_idx, None)

-    from ipex_llm.transformers.models.common import scaled_dot_product_attention
     attn_weights = None
     attn_output = scaled_dot_product_attention(
         query_states, key_states, value_states,
-        attention_mask, q_len == kv_seq_len, math.sqrt(self.head_dim)
+        attention_mask, q_len == kv_seq_len
     )

     attn_output = attn_output.transpose(1, 2).contiguous()

python/llm/src/ipex_llm/transformers/models/minicpmv.py

Lines changed: 5 additions & 3 deletions

@@ -28,6 +28,7 @@
 from torch.nn.functional import linear
 from ipex_llm.transformers.models.common import merge_qkv_base, padding_qkv_hd
 from ipex_llm.transformers.models.common import attention_softmax
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from transformers import AutoProcessor, TextIteratorStreamer
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor


@@ -72,10 +73,11 @@ def siglip_attention_forward(
         72, 80
     )

-    from ipex_llm.transformers.models.common import scaled_dot_product_attention
     attn_weights = None
-    attn_output = scaled_dot_product_attention(query_states, key_states, value_states,
-                                               attention_mask, False, math.sqrt(self.head_dim))
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, False, 1 / math.sqrt(self.head_dim)
+    )

     attn_output = attn_output[:, :, :, :self.head_dim]
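minicpmv.py is the one call site that keeps an explicit scale, and it changes it from math.sqrt(self.head_dim) to 1 / math.sqrt(self.head_dim), i.e. the multiplicative factor applied to QK^T, the same convention PyTorch's scale= argument uses. A plausible reason the default cannot be used here is the padding_qkv_hd(..., 72, 80) call a few lines earlier: the head dimension is padded from 72 to 80, so a scale derived from the padded tensors would be 1/sqrt(80) instead of the model's 1/sqrt(72). The snippet below only illustrates that convention with plain PyTorch; the 72 and 80 come from the hunk, everything else is made up for the example.

# Illustrative only: shows why a padded head dim needs an explicit scale
# when calling a torch-style SDPA (assumed behavior, not IPEX-LLM code).
import math

import torch
import torch.nn.functional as F

bsz, n_heads, seq, head_dim, padded = 1, 16, 64, 72, 80

q = torch.randn(bsz, n_heads, seq, padded)
k = torch.randn(bsz, n_heads, seq, padded)
v = torch.randn(bsz, n_heads, seq, padded)
# Zero the padded tail so it cannot contribute to the dot products,
# which is what zero-padding the head dimension achieves.
q[..., head_dim:] = 0
k[..., head_dim:] = 0

# The default scale would be 1/sqrt(80); the model's real scale is 1/sqrt(72).
out = F.scaled_dot_product_attention(q, k, v, scale=1 / math.sqrt(head_dim))
out = out[..., :head_dim]  # drop the padding afterwards, as the hunk does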

python/llm/src/ipex_llm/transformers/models/qwen2.py

Lines changed: 1 addition & 1 deletion

@@ -595,7 +595,7 @@ def qwen2_attention_forward(
     else:
         attn_output = scaled_dot_product_attention(
             query_states, key_states, value_states,
-            attention_mask, q_len == kv_seq_len, math.sqrt(self.head_dim)
+            attention_mask, q_len == kv_seq_len
         )

     attn_output = attn_output.transpose(1, 2).contiguous()

python/llm/src/ipex_llm/transformers/models/starcoder2.py

Lines changed: 9 additions & 41 deletions

@@ -40,17 +40,15 @@
 import torch
 import warnings

-from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
-from ipex_llm.transformers.models.utils import (
-    use_quantize_kv_cache, restore_fp8_kv_cache,
-    should_use_fuse_rope, use_sdp, use_sdp_causal
-)
+from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache, should_use_fuse_rope
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 from ipex_llm.utils.common.log4Error import invalidInputError

 from typing import Optional, Tuple, List
 from transformers.cache_utils import Cache
-from transformers.models.starcoder2.modeling_starcoder2 import repeat_kv, apply_rotary_pos_emb
+from transformers.models.starcoder2.modeling_starcoder2 import apply_rotary_pos_emb
 from transformers.models.starcoder2.modeling_starcoder2 import Starcoder2Model, Starcoder2Attention



@@ -103,41 +101,11 @@ def attention_forward(
                                                      self.layer_idx, None)

     # IPEX-LLM OPT: sdp
-    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                            attention_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states,
-                                        attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, attention_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, attention_mask)
-    else:
-        if isinstance(past_key_value, DynamicFp8Cache):
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-
-        # upcast attention to fp32
-        attn_weights = attention_softmax(attn_weights)
-        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
-                                                   training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
+    attn_weights = None
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == kv_seq_len
+    )

     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
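For reference, the branch this hunk removes was the eager fallback: dequantize the FP8 KV cache when needed, repeat the KV heads for grouped-query attention, then compute matmul, softmax and dropout by hand. The condensed sketch below restates that removed math on generic tensors (repeat_kv follows the usual Hugging Face definition, dropout_p stands in for self.attention_dropout); it is only meant to show what the single scaled_dot_product_attention call now has to reproduce.

# Sketch of the eager path removed by this commit, on generic tensors.
import math
from typing import Optional

import torch


def repeat_kv(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim)
    bsz, num_kv_heads, seq_len, head_dim = hidden.shape
    if n_rep == 1:
        return hidden
    hidden = hidden[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)


def eager_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                    attention_mask: Optional[torch.Tensor], n_rep: int,
                    dropout_p: float = 0.0, training: bool = False) -> torch.Tensor:
    # Repeat k/v heads if n_kv_heads < n_heads, as the removed code did.
    key = repeat_kv(key, n_rep)
    value = repeat_kv(value, n_rep)
    head_dim = query.size(-1)
    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    # Upcast the softmax to fp32, as the removed code did via attention_softmax.
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout_p, training=training)
    return torch.matmul(attn_weights, value)  # (batch, num_heads, q_len, head_dim)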

python/llm/src/ipex_llm/transformers/models/yuan.py

Lines changed: 9 additions & 31 deletions

@@ -26,12 +26,12 @@
 import torch

 from ipex_llm.utils.common import invalidInputError
-from ipex_llm.transformers.models.common import attention_softmax
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
     mlp_fusion_check, fp16_fusion_check
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import SILU, update_past_key_value
-from ipex_llm.transformers.models.utils import should_use_fuse_rope, use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import should_use_fuse_rope


 def merge_qk(module: torch.nn.Module):

@@ -214,34 +214,12 @@ def yuan_attention_forward(
     )
     past_key_value = (key_states, value_states, before_hidden_states) if use_cache else None

-    # IPEX-LLM OPT: sdp
-    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                            attention_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states,
-                                        attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, attention_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, attention_mask)
-    else:
-        if use_quantize_kv:
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-        # upcast attention to fp32
-        attn_weights = attention_softmax(attn_weights)
-        attn_output = torch.matmul(attn_weights, value_states)
+    # IPEX-LLM OPT: sdpa
+    attn_weights = None
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == kv_seq_len
+    )

     attn_output = attn_output.transpose(1, 2)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
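In both rewritten forwards the causal flag is q_len == kv_seq_len rather than a constant, which reads like the usual prefill/decode split: during prefill the query spans the whole, equally long key/value sequence, so causal masking applies; during decode q_len is 1 while kv_seq_len covers the cache, so the new token may attend to everything and the flag is False. This is an inference from the call sites, not something the diff states; the toy example below shows the same switch with PyTorch's built-in SDPA and arbitrary shapes.

# Illustrates the presumed meaning of the q_len == kv_seq_len flag
# using torch's built-in SDPA; all shapes here are arbitrary.
import torch
import torch.nn.functional as F

bsz, n_heads, head_dim = 1, 8, 64

# Prefill: query length equals cache length -> causal attention.
q_len = kv_seq_len = 32
q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_heads, kv_seq_len, head_dim)
v = torch.randn(bsz, n_heads, kv_seq_len, head_dim)
prefill_out = F.scaled_dot_product_attention(q, k, v, is_causal=(q_len == kv_seq_len))

# Decode: one new token against the full cache -> no causal mask needed.
q_len, kv_seq_len = 1, 33
q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_heads, kv_seq_len, head_dim)
v = torch.randn(bsz, n_heads, kv_seq_len, head_dim)
decode_out = F.scaled_dot_product_attention(q, k, v, is_causal=(q_len == kv_seq_len))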
