|
1 | 1 | #include "PagedAttention.h"
|
2 | 2 | #include <torch/all.h>
|
3 | 3 | #include <torch/csrc/autograd/function.h>
|
| 4 | +#include "csrc/utils/CustomOperatorRegistration.h" |
4 | 5 |
|
5 | 6 | namespace torch_ipex {
|
6 | 7 | namespace cpu {
|
@@ -84,28 +85,17 @@ void flash_attn_varlen_cpu(
|
84 | 85 | namespace {
|
85 | 86 |
|
86 | 87 | TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
|
87 |
| - m.def( |
88 |
| - "single_query_cached_kv_attention(Tensor (a!)out, Tensor query, Tensor key_cache, Tensor value_cache,\ |
89 |
| - Tensor head_mapping, float scale, Tensor block_tables, Tensor context_lens, int block_size, int max_context_len,\ |
90 |
| - Tensor? alibi_slopes)-> ()"); |
91 |
| - m.impl( |
| 88 | + IPEX_OP_REGISTER_DISPATCH( |
92 | 89 | "single_query_cached_kv_attention",
|
93 |
| - c10::DispatchKey::CPU, |
94 |
| - torch_ipex::cpu::single_query_cached_kv_attention_forward_cpu); |
95 |
| - m.def( |
96 |
| - "reshape_and_cache(Tensor key, Tensor value, Tensor (a!)key_cache, Tensor (a!)value_cache, Tensor slot_mapping)-> ()"); |
97 |
| - m.impl( |
| 90 | + torch_ipex::cpu::single_query_cached_kv_attention_forward_cpu, |
| 91 | + c10::DispatchKey::CPU); |
| 92 | + IPEX_OP_REGISTER_DISPATCH( |
98 | 93 | "reshape_and_cache",
|
99 |
| - c10::DispatchKey::CPU, |
100 |
| - torch_ipex::cpu::reshape_and_cache_cpu); |
101 |
| - m.def( |
102 |
| - "flash_attn_varlen_func(Tensor (a!)out, Tensor (a!)query, Tensor (a!)key, Tensor (a!)value, Tensor(a!) cu_seqlens_q,\ |
103 |
| - Tensor(a!) cu_seqlens_kv, int max_seqlen_q, int max_seqlen_kv, float softmax_scale, bool is_causal, Tensor(a!) block_table, \ |
104 |
| - Tensor? alibi_slopes)-> ()"); |
105 |
| - |
106 |
| - m.impl( |
| 94 | + torch_ipex::cpu::reshape_and_cache_cpu, |
| 95 | + c10::DispatchKey::CPU); |
| 96 | + IPEX_OP_REGISTER_DISPATCH( |
107 | 97 | "flash_attn_varlen_func",
|
108 |
| - c10::DispatchKey::CPU, |
109 |
| - torch_ipex::cpu::flash_attn_varlen_cpu); |
| 98 | + torch_ipex::cpu::flash_attn_varlen_cpu, |
| 99 | + c10::DispatchKey::CPU); |
110 | 100 | }
|
111 | 101 | } // namespace
|
0 commit comments