@@ -150,6 +150,9 @@ constexpr int block_size_m() {
 constexpr int block_size_n() {
   return 8 * TILE_N;
 }
+constexpr int block_size_n2() {
+  return 2 * TILE_N;
+}
 // convert to vnni format
 // from [N, K] to [K/2, N, 2] for bfloat16 and float16
 //
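
(Aside, not part of the patch.) To make the `[N, K] -> [K/2, N, 2]` comment concrete, here is a minimal reference repack showing the VNNI index mapping; `uint16_t` stands in for the 16-bit bfloat16/float16 payload, and this is an illustration, not the actual `pack_vnni` kernel:

```cpp
#include <cstdint>
#include <vector>

// Reference VNNI repack for one [N, K] block of 16-bit values, K even:
// dst[k/2][n][k%2] = src[n][k], i.e. pairs of adjacent K elements are
// interleaved so a dot-product instruction can consume two at a time.
std::vector<uint16_t> pack_vnni_reference(
    const std::vector<uint16_t>& src, int N, int K) {
  std::vector<uint16_t> dst(static_cast<size_t>(N) * K);
  for (int k2 = 0; k2 < K / 2; ++k2) {
    for (int n = 0; n < N; ++n) {
      for (int r = 0; r < 2; ++r) {
        dst[(static_cast<size_t>(k2) * N + n) * 2 + r] =
            src[static_cast<size_t>(n) * K + 2 * k2 + r];
      }
    }
  }
  return dst;
}
```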
@@ -170,45 +173,76 @@ inline void pack_vnni(
   }
 }

-at::Tensor convert_weight_packed_bf16(at::Tensor& weight) {
+std::tuple<at::Tensor, at::Tensor> convert_weight_packed_moe_bf16(
+    at::Tensor& w1,
+    at::Tensor& w2) {
   // weight : [E, OC, IC]
   // w1 : [E, 2N, K]
   // w2 : [E, K, N]
-  CHECK_DIM(3, weight);
-  const auto st = weight.scalar_type();
-  const int E = weight.size(0);
-  const int OC = weight.size(1);
-  const int IC = weight.size(2);
+  CHECK_DIM(3, w1);
+  CHECK_DIM(3, w2);
+  const auto st1 = w1.scalar_type();
+  const auto st2 = w2.scalar_type();
+  TORCH_CHECK(st1 == st2, "weight type mismatch");
+  const int E1 = w1.size(0);
+  const int E2 = w2.size(0);
+  const int OC1 = w1.size(1);
+  const int IC1 = w1.size(2);
+  const int OC2 = w2.size(1);
+  const int IC2 = w2.size(2);
+  TORCH_CHECK(E1 == E2 && IC2 * 2 == OC1 && IC1 == OC2, "weight size mismatch");
+  const int N = IC2;
+  const int K = IC1;
   // we handle 2 TILE_N at a time.
-  TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC);
-  TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC);
-  constexpr int BLOCK_N = block_size_n();
+  TORCH_CHECK(N % TILE_N == 0, "invalid weight out features ", N);
+  TORCH_CHECK(K % TILE_K == 0, "invalid weight input features ", K);
+  int BLOCK_N = block_size_n();
+  if (N < 2 * BLOCK_N) {
+    BLOCK_N = block_size_n2();
+  }
+
   // use phony sizes here [E, OC, IC], for each [E], [OC, IC] -> [IC / 2, OC, 2]
-  auto packed_weight = at::empty({E, OC, IC}, weight.options());
-  const int stride = OC * IC;
+  auto packed_weight1 = at::empty({E1, OC1, IC1}, w1.options());
+  auto packed_weight2 = at::empty({E2, OC2, IC2}, w2.options());
+  const int stride1 = OC1 * IC1;
+  const int stride2 = OC2 * IC2;
   // TODO: add float8 support
   TORCH_CHECK(
-      st == at::kBFloat16 || st == at::kHalf,
+      st1 == at::kBFloat16 || st1 == at::kHalf,
       "expect weight to be bfloat16 or float16.");
-  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "conver_weight_packed_impl", [&] {
-    const scalar_t* w_data = weight.data_ptr<scalar_t>();
-    scalar_t* packed_data = packed_weight.data_ptr<scalar_t>();
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(st1, "conver_weight_packed_impl", [&] {
+    const scalar_t* w_data1 = w1.data_ptr<scalar_t>();
+    const scalar_t* w_data2 = w2.data_ptr<scalar_t>();
+    scalar_t* packed_data1 = packed_weight1.data_ptr<scalar_t>();
+    scalar_t* packed_data2 = packed_weight2.data_ptr<scalar_t>();
     // parallel on {E}
-    at::parallel_for(0, E, 0, [&](int begin, int end) {
+    at::parallel_for(0, E1, 0, [&](int begin, int end) {
+      for (int e = begin; e < end; ++e) {
+        for (int n = 0; n < OC1; n += BLOCK_N) {
+          int n_size = std::min(BLOCK_N, OC1 - n);
+          pack_vnni<scalar_t>(
+              packed_data1 + e * stride1 + n * IC1,
+              w_data1 + e * stride1 + n * IC1,
+              n_size,
+              IC1);
+        }
+      }
+    });
+    at::parallel_for(0, E2, 0, [&](int begin, int end) {
       for (int e = begin; e < end; ++e) {
-        for (int n = 0; n < OC; n += BLOCK_N) {
-          int n_size = std::min(BLOCK_N, OC - n);
+        for (int n = 0; n < OC2; n += BLOCK_N) {
+          int n_size = std::min(BLOCK_N, OC2 - n);
           pack_vnni<scalar_t>(
-              packed_data + e * stride + n * IC,
-              w_data + e * stride + n * IC,
+              packed_data2 + e * stride2 + n * IC2,
+              w_data2 + e * stride2 + n * IC2,
               n_size,
-              IC);
+              IC2);
         }
       }
     });
   });

-  return packed_weight;
+  return std::make_tuple(packed_weight1, packed_weight2);
 }

 template <typename scalar_t, int SIZE>
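
(Aside, not part of the patch.) Note that `BLOCK_N` changes from `constexpr` to a runtime value: when a small expert dimension `N` cannot fill two wide blocks, the code falls back to the narrower `block_size_n2()`. A self-contained sketch of that selection, assuming `TILE_N` is 16 (the actual value comes from the kernel headers):

```cpp
#include <cassert>

constexpr int TILE_N = 16; // assumption: 16-wide AMX-style tiles

constexpr int block_size_n() { return 8 * TILE_N; }  // 128
constexpr int block_size_n2() { return 2 * TILE_N; } // 32

// Mirrors the runtime fallback in the patch: small N switches to the
// narrower block so the N dimension still splits across several blocks.
int select_block_n(int N) {
  int BLOCK_N = block_size_n();
  if (N < 2 * BLOCK_N) {
    BLOCK_N = block_size_n2();
  }
  return BLOCK_N;
}

int main() {
  assert(select_block_n(1024) == 128); // large N keeps the 8-tile block
  assert(select_block_n(192) == 32);   // N < 256 falls back to 2 tiles
  return 0;
}
```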
@@ -485,10 +519,11 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
       "grouped_topk(Tensor hidden_states, Tensor gating_output, \
       int topk, bool renormalize, int num_expert_group, int topk_group, Tensor e_score_correction_bias, Tensor routed_scaling_factor) -> (Tensor, Tensor)");
   m.impl("grouped_topk", c10::DispatchKey::CPU, torch_ipex::cpu::grouped_topk);
-  m.def("convert_weight_packed_bf16(Tensor weight) -> Tensor");
+  m.def(
+      "convert_weight_packed_moe_bf16(Tensor weight1, Tensor weight2) -> (Tensor, Tensor)");
   m.impl(
-      "convert_weight_packed_bf16",
+      "convert_weight_packed_moe_bf16",
       c10::DispatchKey::CPU,
-      torch_ipex::cpu::convert_weight_packed_bf16);
+      torch_ipex::cpu::convert_weight_packed_moe_bf16);
 }
 } // namespace
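
(Aside, not part of the patch.) A hedged caller-side sketch of the new schema, assuming the extension library is already loaded so the op is visible to the dispatcher; the shapes are illustrative values that satisfy the `IC2 * 2 == OC1` and `IC1 == OC2` checks (E = 4, K = 512, N = 256):

```cpp
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <tuple>

std::tuple<at::Tensor, at::Tensor> pack_moe_weights_example() {
  auto w1 = at::randn({4, 512, 512}, at::kBFloat16); // [E, 2N, K]
  auto w2 = at::randn({4, 512, 256}, at::kBFloat16); // [E, K, N]
  auto op = c10::Dispatcher::singleton()
                .findSchemaOrThrow("torch_ipex::convert_weight_packed_moe_bf16", "")
                .typed<std::tuple<at::Tensor, at::Tensor>(
                    const at::Tensor&, const at::Tensor&)>();
  return op.call(w1, w2);
}
```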