Skip to content

Commit c85ebe4

Browse files
fix build (#3279)
Co-authored-by: Xu Han <[email protected]>
1 parent 9f6178e commit c85ebe4

File tree

1 file changed

+10
-20
lines changed

1 file changed

+10
-20
lines changed

csrc/cpu/aten/PagedAttention.cpp

Lines changed: 10 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
#include "PagedAttention.h"
22
#include <torch/all.h>
33
#include <torch/csrc/autograd/function.h>
4+
#include "csrc/utils/CustomOperatorRegistration.h"
45

56
namespace torch_ipex {
67
namespace cpu {
@@ -84,28 +85,17 @@ void flash_attn_varlen_cpu(
8485
namespace {
8586

8687
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
87-
m.def(
88-
"single_query_cached_kv_attention(Tensor (a!)out, Tensor query, Tensor key_cache, Tensor value_cache,\
89-
Tensor head_mapping, float scale, Tensor block_tables, Tensor context_lens, int block_size, int max_context_len,\
90-
Tensor? alibi_slopes)-> ()");
91-
m.impl(
88+
IPEX_OP_REGISTER_DISPATCH(
9289
"single_query_cached_kv_attention",
93-
c10::DispatchKey::CPU,
94-
torch_ipex::cpu::single_query_cached_kv_attention_forward_cpu);
95-
m.def(
96-
"reshape_and_cache(Tensor key, Tensor value, Tensor (a!)key_cache, Tensor (a!)value_cache, Tensor slot_mapping)-> ()");
97-
m.impl(
90+
torch_ipex::cpu::single_query_cached_kv_attention_forward_cpu,
91+
c10::DispatchKey::CPU);
92+
IPEX_OP_REGISTER_DISPATCH(
9893
"reshape_and_cache",
99-
c10::DispatchKey::CPU,
100-
torch_ipex::cpu::reshape_and_cache_cpu);
101-
m.def(
102-
"flash_attn_varlen_func(Tensor (a!)out, Tensor (a!)query, Tensor (a!)key, Tensor (a!)value, Tensor(a!) cu_seqlens_q,\
103-
Tensor(a!) cu_seqlens_kv, int max_seqlen_q, int max_seqlen_kv, float softmax_scale, bool is_causal, Tensor(a!) block_table, \
104-
Tensor? alibi_slopes)-> ()");
105-
106-
m.impl(
94+
torch_ipex::cpu::reshape_and_cache_cpu,
95+
c10::DispatchKey::CPU);
96+
IPEX_OP_REGISTER_DISPATCH(
10797
"flash_attn_varlen_func",
108-
c10::DispatchKey::CPU,
109-
torch_ipex::cpu::flash_attn_varlen_cpu);
98+
torch_ipex::cpu::flash_attn_varlen_cpu,
99+
c10::DispatchKey::CPU);
110100
}
111101
} // namespace

0 commit comments

Comments
 (0)