@@ -7,46 +7,97 @@ namespace torch_ipex {
7
7
namespace cpu {
8
8
9
9
IPEX_DEFINE_DISPATCH (punica_bgmv_shrink_kernel_stub);
10
+ IPEX_DEFINE_DISPATCH (punica_sgmv_shrink_kernel_stub);
10
11
IPEX_DEFINE_DISPATCH (punica_bgmv_expand_kernel_stub);
12
+ IPEX_DEFINE_DISPATCH (punica_sgmv_expand_kernel_stub);
11
13
IPEX_DEFINE_DISPATCH (punica_bgmv_expand_slice_kernel_stub);
14
+ IPEX_DEFINE_DISPATCH (punica_sgmv_expand_slice_kernel_stub);
12
15
13
- void punica_bgmv_shrink_forward_cpu (
16
+ at::Tensor punica_bgmv_shrink_forward_cpu (
14
17
at::Tensor& out,
15
18
at::Tensor& input,
16
19
at::Tensor& weights,
17
20
at::Tensor& indicies,
18
21
const double scale) {
19
- return punica_bgmv_shrink_kernel_stub (
20
- kCPU , out, input, weights, indicies, scale) ;
22
+ punica_bgmv_shrink_kernel_stub (kCPU , out, input, weights, indicies, scale);
23
+ return out;
21
24
}
22
25
23
- void punica_bgmv_expand_forward_cpu (
26
+ at::Tensor punica_sgmv_shrink_forward_cpu (
27
+ at::Tensor& out,
28
+ at::Tensor& input,
29
+ at::Tensor& weights,
30
+ at::Tensor& indicies,
31
+ at::Tensor& seq_lens,
32
+ const double scale) {
33
+ punica_sgmv_shrink_kernel_stub (
34
+ kCPU , out, input, weights, indicies, seq_lens, scale);
35
+ return out;
36
+ }
37
+
38
+ at::Tensor punica_bgmv_expand_forward_cpu (
24
39
at::Tensor& out,
25
40
at::Tensor& input,
26
41
at::Tensor& weights,
27
42
at::Tensor& indicies,
28
43
bool add_inputs) {
29
- return punica_bgmv_expand_kernel_stub (
44
+ punica_bgmv_expand_kernel_stub (
30
45
kCPU , out, input, weights, indicies, add_inputs);
46
+ return out;
47
+ }
48
+
49
+ at::Tensor punica_sgmv_expand_forward_cpu (
50
+ at::Tensor& out,
51
+ at::Tensor& input,
52
+ at::Tensor& weights,
53
+ at::Tensor& indicies,
54
+ at::Tensor& seq_lens,
55
+ bool add_inputs) {
56
+ punica_sgmv_expand_kernel_stub (
57
+ kCPU , out, input, weights, indicies, seq_lens, add_inputs);
58
+ return out;
59
+ }
60
+
61
+ at::Tensor punica_bgmv_expand_slice_forward_cpu (
62
+ at::Tensor& out,
63
+ at::Tensor& input,
64
+ at::Tensor& weights,
65
+ at::Tensor& indicies,
66
+ int64_t slice_offset,
67
+ int64_t slice_size,
68
+ bool add_inputs) {
69
+ punica_bgmv_expand_slice_kernel_stub (
70
+ kCPU ,
71
+ out,
72
+ input,
73
+ weights,
74
+ indicies,
75
+ slice_offset,
76
+ slice_size,
77
+ add_inputs);
78
+ return out;
31
79
}
32
80
33
- void punica_bgmv_expand_slice_forward_cpu (
81
+ at::Tensor punica_sgmv_expand_slice_forward_cpu (
34
82
at::Tensor& out,
35
83
at::Tensor& input,
36
84
at::Tensor& weights,
37
85
at::Tensor& indicies,
86
+ at::Tensor& seq_lens,
38
87
int64_t slice_offset,
39
88
int64_t slice_size,
40
89
bool add_inputs) {
41
- return punica_bgmv_expand_slice_kernel_stub (
90
+ punica_sgmv_expand_slice_kernel_stub (
42
91
kCPU ,
43
92
out,
44
93
input,
45
94
weights,
46
95
indicies,
96
+ seq_lens,
47
97
slice_offset,
48
98
slice_size,
49
99
add_inputs);
100
+ return out;
50
101
}
51
102
52
103
} // namespace cpu
@@ -61,18 +112,39 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
61
112
c10::DispatchKey::CPU);
62
113
}
63
114
115
+ TORCH_LIBRARY_FRAGMENT (torch_ipex, m) {
116
+ IPEX_OP_REGISTER_DISPATCH (
117
+ " punica_sgmv_shrink" ,
118
+ torch_ipex::cpu::punica_sgmv_shrink_forward_cpu,
119
+ c10::DispatchKey::CPU);
120
+ }
121
+
64
122
TORCH_LIBRARY_FRAGMENT (torch_ipex, m) {
65
123
IPEX_OP_REGISTER_DISPATCH (
66
124
" punica_bgmv_expand" ,
67
125
torch_ipex::cpu::punica_bgmv_expand_forward_cpu,
68
126
c10::DispatchKey::CPU);
69
127
}
70
128
129
+ TORCH_LIBRARY_FRAGMENT (torch_ipex, m) {
130
+ IPEX_OP_REGISTER_DISPATCH (
131
+ " punica_sgmv_expand" ,
132
+ torch_ipex::cpu::punica_sgmv_expand_forward_cpu,
133
+ c10::DispatchKey::CPU);
134
+ }
135
+
71
136
TORCH_LIBRARY_FRAGMENT (torch_ipex, m) {
72
137
IPEX_OP_REGISTER_DISPATCH (
73
138
" punica_bgmv_expand_slice" ,
74
139
torch_ipex::cpu::punica_bgmv_expand_slice_forward_cpu,
75
140
c10::DispatchKey::CPU);
76
141
}
77
142
143
+ TORCH_LIBRARY_FRAGMENT (torch_ipex, m) {
144
+ IPEX_OP_REGISTER_DISPATCH (
145
+ " punica_sgmv_expand_slice" ,
146
+ torch_ipex::cpu::punica_sgmv_expand_slice_forward_cpu,
147
+ c10::DispatchKey::CPU);
148
+ }
149
+
78
150
} // namespace
0 commit comments