Commit 098eb33

refactor sd 1.5 and qwen2-vl and fix (#12590)
1 parent b050368 commit 098eb33

4 files changed: 23 additions & 58 deletions

python/llm/src/ipex_llm/transformers/models/minicpmv.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ def siglip_attention_forward(

     attn_weights = None
     attn_output = scaled_dot_product_attention(
-        query_states, key_states, value_states,
+        query_states, key_states.contiguous(), value_states.contiguous(),
         attention_mask, False, 1 / math.sqrt(self.head_dim)
     )

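Note on the change above: the only difference is that key/value are materialized as contiguous tensors before the fused SDP helper runs. A minimal illustration of that pattern, assuming only the usual requirement that a fused kernel wants standard dense strides (the shapes below are made up):

import torch

k = torch.randn(1, 257, 16, 64).permute(0, 2, 1, 3)  # permuted view: non-contiguous strides
assert not k.is_contiguous()
k = k.contiguous()  # dense copy with row-major strides, safe to hand to a fused kernel
assert k.is_contiguous()
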
python/llm/src/ipex_llm/transformers/models/qwen2.py

Lines changed: 1 addition & 2 deletions
@@ -583,8 +583,7 @@ def qwen2_attention_forward(
                                                       self.layer_idx, None)

     attn_weights = None
-    if query_states.device.type == 'xpu' \
-            and use_flash_attention(query_states, key_states, attention_mask):
+    if use_flash_attention(query_states, key_states, attention_mask):
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

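The flash path kept above still repeats K/V heads when the model uses grouped-query attention ("repeat k/v heads if n_kv_heads < n_heads"). For reference, a sketch of that repeat_kv step in the style of the standard Hugging Face helper, not copied from this repository:

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)
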
python/llm/src/ipex_llm/transformers/models/qwen2_vl.py

Lines changed: 15 additions & 40 deletions
@@ -43,8 +43,9 @@
 import torch

 from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, should_use_fuse_rope
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_sdp_non_causal
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 from ipex_llm.utils.common import invalidInputError

@@ -198,7 +199,6 @@ def qwen2_vision_attention_forward(
                       "unexpected input")

     if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
-        import xe_addons
         image_num = len(seq_lens) - 1
         image_size = seq_lens[1] - seq_lens[0]
         guessed_seq_lens = torch.arange(0, (image_num + 1) * image_size, image_size,

@@ -209,7 +209,10 @@
         v = v.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
         # q, k, v: [image_num, num_heads, image_size, head_dim]

-        attn_output = xe_addons.sdp_non_causal(q, k.contiguous(), v.contiguous(), None)
+        attn_output = scaled_dot_product_attention(
+            q, k.contiguous(), v.contiguous(),
+            None, False
+        )
         attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
         attn_output = attn_output.view(seq_length, self.num_heads, self.head_dim)
         # attn_output: [seq_length, num_heads, head_dim]

@@ -226,7 +229,10 @@
             tmp_q = q[:, :, start_idx:end_idx, :]
             tmp_k = k[:, :, start_idx:end_idx, :]
             tmp_v = v[:, :, start_idx:end_idx, :]
-            attn_output = xe_addons.sdp_non_causal(tmp_q, tmp_k, tmp_v, None)
+            attn_output = scaled_dot_product_attention(
+                tmp_q, tmp_k, tmp_v,
+                None, False
+            )
             attn_output = attn_output.permute(0, 2, 1, 3)
             # attn_output: [1, seq_length, num_heads, head_dim]
             attn_outputs.append(attn_output)

@@ -293,42 +299,11 @@ def qwen2_vl_attention_forward(
     key_states, value_states = past_key_value.update(key_states, value_states,
                                                       self.layer_idx, None)

-    kv_seq_len = key_states.size(2)
-    if attention_mask is not None:  # no matter the length, we just slice it
-        causal_mask = attention_mask[:, :, :, :kv_seq_len]
-
     attn_weights = None
-    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, causal_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states, causal_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, causal_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, causal_mask)
-    else:
-        if isinstance(past_key_value, DynamicFp8Cache):
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if causal_mask is not None:
-            attn_weights = attn_weights + causal_mask
-
-        # upcast attention to fp32
-        attn_weights = attention_softmax(attn_weights)
-        attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == key_states.size(2)
+    )

     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, -1)

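The block removed in the last hunk hand-rolled the dispatch between xe_addons kernels (plain, causal, and fp8 variants) and a matmul-softmax fallback; the replacement routes everything through the shared scaled_dot_product_attention helper. A rough reference of what that fallback computed, assuming the argument order seen in this diff (query, key, value, mask, is_causal, scale) — a sketch, not the library implementation:

import math
import torch

def sdpa_reference(query, key, value, mask=None, is_causal=False, scale=None):
    # query/key/value: [batch, num_heads, seq_len, head_dim]
    scale = scale if scale is not None else 1 / math.sqrt(query.size(-1))
    attn = torch.matmul(query, key.transpose(-1, -2)) * scale
    if is_causal:
        q_len, k_len = query.size(-2), key.size(-2)
        keep = torch.ones(q_len, k_len, dtype=torch.bool, device=query.device).tril()
        attn = attn.masked_fill(~keep, float("-inf"))
    elif mask is not None:
        attn = attn + mask
    attn = torch.softmax(attn, dim=-1, dtype=torch.float32).to(value.dtype)  # upcast softmax to fp32
    return torch.matmul(attn, value)

In the new call, is_causal is derived as q_len == key_states.size(2), i.e. the causal shortcut only applies when no cached keys precede the current query block; during decode the provided attention_mask is used instead.
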
python/llm/src/ipex_llm/transformers/models/sd.py

Lines changed: 6 additions & 15 deletions
@@ -37,8 +37,8 @@
 from typing import Optional

 from ipex_llm.transformers.utils import get_xpu_device_type
-from ipex_llm.transformers.models.common import padding_qkv_hd, attention_softmax
-from ipex_llm.transformers.models.utils import use_sdp_non_causal
+from ipex_llm.transformers.models.common import padding_qkv_hd
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from diffusers.models.attention_processor import Attention


@@ -110,19 +110,10 @@ def __call__(
         if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
             # padding head_dim 40 to 64
             query, key, value = padding_qkv_hd(query, key, value, 40, 64)
-
-            if use_sdp_non_causal(query.size(-1), query.device, query.dtype):
-                import xe_addons
-                hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
-                                                         value.contiguous(), attention_mask)
-            else:
-                scale = 1 / math.sqrt(head_dim)
-                attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
-                if attention_mask is not None:
-                    attn_weights = attn_weights + attention_mask
-                attn_weights = attention_softmax(attn_weights)
-                hidden_states = torch.matmul(attn_weights, value)
-
+            hidden_states = scaled_dot_product_attention(
+                query, key.contiguous(), value.contiguous(),
+                attention_mask, False, 1 / math.sqrt(head_dim)
+            )
             hidden_states = hidden_states[:, :, :, :head_dim]
         else:
             hidden_states = torch.nn.functional.scaled_dot_product_attention(

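The sd.py path keeps its pad-then-slice trick around the new call: head_dim 40 is zero-padded to 64 before attention and the padding is sliced off afterwards, while the softmax scale is still taken from the original head_dim (hence the explicit 1 / math.sqrt(head_dim) argument). A hedged illustration of why that works, using torch's built-in SDP as a stand-in for the fused kernel (shapes are made up; the scale keyword needs PyTorch 2.1 or newer):

import math
import torch
import torch.nn.functional as F

head_dim = 40
q, k, v = (torch.randn(2, 8, 77, head_dim) for _ in range(3))

pad = (0, 64 - head_dim)                             # pad only the last (head_dim) dimension
q64, k64, v64 = (F.pad(t, pad) for t in (q, k, v))   # zero columns leave q @ k^T unchanged
out = F.scaled_dot_product_attention(q64, k64, v64, scale=1 / math.sqrt(head_dim))
out = out[..., :head_dim]                            # drop the padded columns again

ref = F.scaled_dot_product_attention(q, k, v)        # unpadded reference
assert torch.allclose(out, ref, atol=1e-4)
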