Skip to content

Commit f736e1a

Browse files
authored
Support softcap in PagedAttention APIs (#3599)
1 parent 8a5b92c commit f736e1a

File tree

7 files changed

+158
-41
lines changed

7 files changed

+158
-41
lines changed

csrc/cpu/aten/PagedAttention.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ void single_query_cached_kv_attention_forward_cpu(
2626
int64_t max_context_len,
2727
const c10::optional<at::Tensor>& alibi_slopes,
2828
const double k_scale,
29-
const double v_scale) {
29+
const double v_scale,
30+
const double softcap) {
3031
return single_query_cached_kv_attention_kernel_stub(
3132
kCPU,
3233
out,
@@ -41,7 +42,8 @@ void single_query_cached_kv_attention_forward_cpu(
4142
max_context_len,
4243
alibi_slopes,
4344
k_scale,
44-
v_scale);
45+
v_scale,
46+
softcap);
4547
}
4648

4749
void reshape_and_cache_cpu(
@@ -70,7 +72,8 @@ void flash_attn_varlen_cpu(
7072
at::Tensor& block_table,
7173
const c10::optional<at::Tensor>& alibi_slopes,
7274
const double k_scale,
73-
const double v_scale) {
75+
const double v_scale,
76+
const double softcap) {
7477
return flash_attn_var_len_kernel_stub(
7578
kCPU,
7679
out,
@@ -86,7 +89,8 @@ void flash_attn_varlen_cpu(
8689
block_table,
8790
alibi_slopes,
8891
k_scale,
89-
v_scale);
92+
v_scale,
93+
softcap);
9094
}
9195

9296
} // namespace cpu

csrc/cpu/aten/PagedAttention.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ void single_query_cached_kv_attention(
2121
int64_t max_context_len,
2222
const c10::optional<at::Tensor>& alibi_slopes,
2323
const double k_scale,
24-
const double v_scale);
24+
const double v_scale,
25+
const double softcap);
2526
}
2627

2728
void reshape_and_cache(
@@ -47,7 +48,8 @@ void flash_attn_varlen(
4748
at::Tensor& block_table,
4849
const c10::optional<at::Tensor>& alibi_slopes,
4950
const double k_scale,
50-
const double v_scale);
51+
const double v_scale,
52+
const double softcap);
5153

5254
using single_query_cached_kv_attention_fn = void (*)(
5355
at::Tensor& out, // [num_seqs, num_heads, head_size]
@@ -62,7 +64,8 @@ using single_query_cached_kv_attention_fn = void (*)(
6264
int64_t max_context_len,
6365
const c10::optional<at::Tensor>& alibi_slopes,
6466
const double k_scale,
65-
const double v_scale);
67+
const double v_scale,
68+
const double softcap);
6669

6770
using reshape_and_cache_fn = void (*)(
6871
at::Tensor& key,
@@ -87,7 +90,8 @@ using flash_attn_var_len_fn = void (*)(
8790
at::Tensor& block_table,
8891
const c10::optional<at::Tensor>& alibi_slopes,
8992
const double k_scale,
90-
const double v_scale);
93+
const double v_scale,
94+
const double softcap);
9195

9296
IPEX_DECLARE_DISPATCH(
9397
single_query_cached_kv_attention_fn,

0 commit comments

Comments (0)