Skip to content

Commit 88dbb2c

Browse files
authored
Add sgmv support for punica; Support punica on inductor (#3692)
* init sgmv * Add frontend API; Improve code style * update varlen_attention API * add meta punica func * fix ops name * support torch.compile * fix flake * fix clang
1 parent 1f0f139 commit 88dbb2c

File tree

10 files changed

+869
-40
lines changed

10 files changed

+869
-40
lines changed

csrc/cpu/aten/Punica.cpp

Lines changed: 79 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,46 +7,97 @@ namespace torch_ipex {
77
namespace cpu {
88

99
// Dispatch-stub definitions for the punica kernels. Every bgmv stub
// (shrink / expand / expand_slice) is paired with an sgmv counterpart; these
// are the stubs the *_forward_cpu wrappers in this file call.
IPEX_DEFINE_DISPATCH(punica_bgmv_shrink_kernel_stub);
10+
IPEX_DEFINE_DISPATCH(punica_sgmv_shrink_kernel_stub);
1011
IPEX_DEFINE_DISPATCH(punica_bgmv_expand_kernel_stub);
12+
IPEX_DEFINE_DISPATCH(punica_sgmv_expand_kernel_stub);
1113
IPEX_DEFINE_DISPATCH(punica_bgmv_expand_slice_kernel_stub);
14+
IPEX_DEFINE_DISPATCH(punica_sgmv_expand_slice_kernel_stub);
1215

13-
void punica_bgmv_shrink_forward_cpu(
16+
// CPU wrapper for the bgmv shrink op. Passes out, input, weights, indicies
// and scale unchanged to punica_bgmv_shrink_kernel_stub (kCPU is the first
// argument) and returns `out`, the tensor received by reference.
// NOTE(review): presumably the kernel writes its result into `out` — confirm
// against the kernel implementation.
at::Tensor punica_bgmv_shrink_forward_cpu(
1417
at::Tensor& out,
1518
at::Tensor& input,
1619
at::Tensor& weights,
1720
at::Tensor& indicies,
1821
const double scale) {
19-
return punica_bgmv_shrink_kernel_stub(
20-
kCPU, out, input, weights, indicies, scale);
22+
// The stub is invoked as a plain statement; `out` is returned explicitly.
punica_bgmv_shrink_kernel_stub(kCPU, out, input, weights, indicies, scale);
23+
return out;
2124
}
2225

23-
void punica_bgmv_expand_forward_cpu(
26+
// CPU wrapper for the sgmv shrink op. Same shape as the bgmv shrink wrapper
// with one extra tensor argument, seq_lens, which is forwarded to
// punica_sgmv_shrink_kernel_stub between indicies and scale. Returns `out`.
// NOTE(review): seq_lens presumably carries per-segment lengths and the
// kernel presumably writes into `out` — confirm against the kernel.
at::Tensor punica_sgmv_shrink_forward_cpu(
27+
at::Tensor& out,
28+
at::Tensor& input,
29+
at::Tensor& weights,
30+
at::Tensor& indicies,
31+
at::Tensor& seq_lens,
32+
const double scale) {
33+
punica_sgmv_shrink_kernel_stub(
34+
kCPU, out, input, weights, indicies, seq_lens, scale);
35+
return out;
36+
}
37+
38+
// CPU wrapper for the bgmv expand op. Passes out, input, weights, indicies
// and add_inputs unchanged to punica_bgmv_expand_kernel_stub (kCPU first) and
// returns `out`.
// NOTE(review): add_inputs presumably selects accumulate-into-`out` versus
// overwrite — confirm against the kernel implementation.
at::Tensor punica_bgmv_expand_forward_cpu(
2439
at::Tensor& out,
2540
at::Tensor& input,
2641
at::Tensor& weights,
2742
at::Tensor& indicies,
2843
bool add_inputs) {
29-
return punica_bgmv_expand_kernel_stub(
44+
punica_bgmv_expand_kernel_stub(
3045
kCPU, out, input, weights, indicies, add_inputs);
46+
return out;
47+
}
48+
49+
// CPU wrapper for the sgmv expand op. Mirrors the bgmv expand wrapper with an
// additional seq_lens tensor, forwarded to punica_sgmv_expand_kernel_stub
// between indicies and add_inputs. Returns `out`.
// NOTE(review): meaning of seq_lens / add_inputs is not visible here —
// confirm against the kernel implementation.
at::Tensor punica_sgmv_expand_forward_cpu(
50+
at::Tensor& out,
51+
at::Tensor& input,
52+
at::Tensor& weights,
53+
at::Tensor& indicies,
54+
at::Tensor& seq_lens,
55+
bool add_inputs) {
56+
punica_sgmv_expand_kernel_stub(
57+
kCPU, out, input, weights, indicies, seq_lens, add_inputs);
58+
return out;
59+
}
60+
61+
// CPU wrapper for the bgmv expand_slice op. Forwards out, input, weights,
// indicies, slice_offset, slice_size and add_inputs, in that order, to
// punica_bgmv_expand_slice_kernel_stub (kCPU first) and returns `out`.
// NOTE(review): slice_offset / slice_size presumably select a sub-range of
// `out` to write — confirm against the kernel implementation.
at::Tensor punica_bgmv_expand_slice_forward_cpu(
62+
at::Tensor& out,
63+
at::Tensor& input,
64+
at::Tensor& weights,
65+
at::Tensor& indicies,
66+
int64_t slice_offset,
67+
int64_t slice_size,
68+
bool add_inputs) {
69+
punica_bgmv_expand_slice_kernel_stub(
70+
kCPU,
71+
out,
72+
input,
73+
weights,
74+
indicies,
75+
slice_offset,
76+
slice_size,
77+
add_inputs);
78+
return out;
3179
}
3280

33-
void punica_bgmv_expand_slice_forward_cpu(
81+
// CPU wrapper for the sgmv expand_slice op. Mirrors the bgmv expand_slice
// wrapper with an additional seq_lens tensor, forwarded to
// punica_sgmv_expand_slice_kernel_stub between indicies and slice_offset.
// Returns `out`.
// NOTE(review): semantics of seq_lens and the slice parameters are not
// visible here — confirm against the kernel implementation.
at::Tensor punica_sgmv_expand_slice_forward_cpu(
3482
at::Tensor& out,
3583
at::Tensor& input,
3684
at::Tensor& weights,
3785
at::Tensor& indicies,
86+
at::Tensor& seq_lens,
3887
int64_t slice_offset,
3988
int64_t slice_size,
4089
bool add_inputs) {
41-
return punica_bgmv_expand_slice_kernel_stub(
90+
punica_sgmv_expand_slice_kernel_stub(
4291
kCPU,
4392
out,
4493
input,
4594
weights,
4695
indicies,
96+
seq_lens,
4797
slice_offset,
4898
slice_size,
4999
add_inputs);
100+
return out;
50101
}
51102

52103
} // namespace cpu
@@ -61,18 +112,39 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
61112
c10::DispatchKey::CPU);
62113
}
63114

115+
// Registers punica_sgmv_shrink_forward_cpu under the op name
// "punica_sgmv_shrink" for the CPU dispatch key in the torch_ipex library.
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
116+
IPEX_OP_REGISTER_DISPATCH(
117+
"punica_sgmv_shrink",
118+
torch_ipex::cpu::punica_sgmv_shrink_forward_cpu,
119+
c10::DispatchKey::CPU);
120+
}
121+
64122
// Registers punica_bgmv_expand_forward_cpu under the op name
// "punica_bgmv_expand" for the CPU dispatch key in the torch_ipex library.
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
65123
IPEX_OP_REGISTER_DISPATCH(
66124
"punica_bgmv_expand",
67125
torch_ipex::cpu::punica_bgmv_expand_forward_cpu,
68126
c10::DispatchKey::CPU);
69127
}
70128

129+
// Registers punica_sgmv_expand_forward_cpu under the op name
// "punica_sgmv_expand" for the CPU dispatch key in the torch_ipex library.
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
130+
IPEX_OP_REGISTER_DISPATCH(
131+
"punica_sgmv_expand",
132+
torch_ipex::cpu::punica_sgmv_expand_forward_cpu,
133+
c10::DispatchKey::CPU);
134+
}
135+
71136
// Registers punica_bgmv_expand_slice_forward_cpu under the op name
// "punica_bgmv_expand_slice" for the CPU dispatch key in the torch_ipex
// library.
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
72137
IPEX_OP_REGISTER_DISPATCH(
73138
"punica_bgmv_expand_slice",
74139
torch_ipex::cpu::punica_bgmv_expand_slice_forward_cpu,
75140
c10::DispatchKey::CPU);
76141
}
77142

143+
// Registers punica_sgmv_expand_slice_forward_cpu under the op name
// "punica_sgmv_expand_slice" for the CPU dispatch key in the torch_ipex
// library.
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
144+
IPEX_OP_REGISTER_DISPATCH(
145+
"punica_sgmv_expand_slice",
146+
torch_ipex::cpu::punica_sgmv_expand_slice_forward_cpu,
147+
c10::DispatchKey::CPU);
148+
}
149+
78150
} // namespace

csrc/cpu/aten/Punica.h

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,29 @@ void punica_bgmv_shrink(
1515
at::Tensor& indicies,
1616
const double scale);
1717

18+
// Declaration of the sgmv shrink kernel. Returns void; all tensors are taken
// by non-const reference. The parameter list matches punica_sgmv_shrink_fn.
// NOTE(review): presumably the result is written into `out` and seq_lens
// holds per-segment lengths — confirm against the kernel implementation.
void punica_sgmv_shrink(
19+
at::Tensor& out,
20+
at::Tensor& input,
21+
at::Tensor& weights,
22+
at::Tensor& indicies,
23+
at::Tensor& seq_lens,
24+
const double scale);
25+
1826
// Declaration of the bgmv expand kernel. Returns void; all tensors are taken
// by non-const reference. The parameter list matches punica_bgmv_expand_fn.
// NOTE(review): presumably the result is written into `out`; the effect of
// add_inputs is not visible here — confirm against the kernel implementation.
void punica_bgmv_expand(
1927
at::Tensor& out,
2028
at::Tensor& input,
2129
at::Tensor& weights,
2230
at::Tensor& indicies,
2331
bool add_inputs);
2432

33+
// Declaration of the sgmv expand kernel. Same parameters as the bgmv expand
// declaration plus a seq_lens tensor before add_inputs; matches
// punica_sgmv_expand_fn.
// NOTE(review): semantics of seq_lens / add_inputs are not visible here —
// confirm against the kernel implementation.
void punica_sgmv_expand(
34+
at::Tensor& out,
35+
at::Tensor& input,
36+
at::Tensor& weights,
37+
at::Tensor& indicies,
38+
at::Tensor& seq_lens,
39+
bool add_inputs);
40+
2541
void punica_bgmv_expand_slice(
2642
at::Tensor& out,
2743
at::Tensor& input,
@@ -30,6 +46,16 @@ void punica_bgmv_expand_slice(
3046
int64_t slice_offset,
3147
int64_t slice_size,
3248
bool add_inputs);
49+
50+
// Declaration of the sgmv expand_slice kernel. Returns void; the parameter
// list matches punica_sgmv_expand_slice_fn (seq_lens sits between indicies
// and slice_offset).
// NOTE(review): the enclosing namespace closes with a bare `// namespace`
// comment, so it looks anonymous; if so, every translation unit that includes
// this header gets its own internal-linkage declaration — confirm intended.
void punica_sgmv_expand_slice(
51+
at::Tensor& out,
52+
at::Tensor& input,
53+
at::Tensor& weights,
54+
at::Tensor& indicies,
55+
at::Tensor& seq_lens,
56+
int64_t slice_offset,
57+
int64_t slice_size,
58+
bool add_inputs);
3359
} // namespace
3460

3561
using punica_bgmv_shrink_fn = void (*)(
@@ -39,13 +65,29 @@ using punica_bgmv_shrink_fn = void (*)(
3965
at::Tensor& indicies,
4066
const double scale);
4167

68+
// Function-pointer type used as the signature of
// punica_sgmv_shrink_kernel_stub (see its IPEX_DECLARE_DISPATCH). Returns
// void, so the kernel itself yields no tensor.
using punica_sgmv_shrink_fn = void (*)(
69+
at::Tensor& out,
70+
at::Tensor& input,
71+
at::Tensor& weights,
72+
at::Tensor& indicies,
73+
at::Tensor& seq_lens,
74+
const double scale);
75+
4276
// Function-pointer type used as the signature of
// punica_bgmv_expand_kernel_stub (see its IPEX_DECLARE_DISPATCH). Returns
// void.
using punica_bgmv_expand_fn = void (*)(
4377
at::Tensor& out,
4478
at::Tensor& input,
4579
at::Tensor& weights,
4680
at::Tensor& indicies,
4781
bool add_inputs);
4882

83+
// Function-pointer type used as the signature of
// punica_sgmv_expand_kernel_stub (see its IPEX_DECLARE_DISPATCH). Same as
// punica_bgmv_expand_fn with an extra seq_lens tensor before add_inputs.
using punica_sgmv_expand_fn = void (*)(
84+
at::Tensor& out,
85+
at::Tensor& input,
86+
at::Tensor& weights,
87+
at::Tensor& indicies,
88+
at::Tensor& seq_lens,
89+
bool add_inputs);
90+
4991
using punica_bgmv_expand_slice_fn = void (*)(
5092
at::Tensor& out,
5193
at::Tensor& input,
@@ -55,13 +97,31 @@ using punica_bgmv_expand_slice_fn = void (*)(
5597
int64_t slice_size,
5698
bool add_inputs);
5799

100+
// Function-pointer type used as the signature of
// punica_sgmv_expand_slice_kernel_stub (see its IPEX_DECLARE_DISPATCH).
// Returns void; seq_lens sits between indicies and slice_offset.
using punica_sgmv_expand_slice_fn = void (*)(
101+
at::Tensor& out,
102+
at::Tensor& input,
103+
at::Tensor& weights,
104+
at::Tensor& indicies,
105+
at::Tensor& seq_lens,
106+
int64_t slice_offset,
107+
int64_t slice_size,
108+
bool add_inputs);
109+
58110
// Dispatch-stub declarations: each line pairs a *_fn function-pointer type
// with the stub name of the matching punica kernel, bgmv and sgmv variants
// side by side.
IPEX_DECLARE_DISPATCH(punica_bgmv_shrink_fn, punica_bgmv_shrink_kernel_stub);
59111

112+
IPEX_DECLARE_DISPATCH(punica_sgmv_shrink_fn, punica_sgmv_shrink_kernel_stub);
113+
60114
IPEX_DECLARE_DISPATCH(punica_bgmv_expand_fn, punica_bgmv_expand_kernel_stub);
61115

116+
IPEX_DECLARE_DISPATCH(punica_sgmv_expand_fn, punica_sgmv_expand_kernel_stub);
117+
62118
IPEX_DECLARE_DISPATCH(
63119
punica_bgmv_expand_slice_fn,
64120
punica_bgmv_expand_slice_kernel_stub);
65121

122+
IPEX_DECLARE_DISPATCH(
123+
punica_sgmv_expand_slice_fn,
124+
punica_sgmv_expand_slice_kernel_stub);
125+
66126
} // namespace cpu
67127
} // namespace torch_ipex

0 commit comments

Comments
 (0)