@@ -59,6 +59,7 @@ __device__ void moe_fused_gate_impl(
59
59
int64_t topk,
60
60
int64_t num_fused_shared_experts,
61
61
double routed_scaling_factor,
62
+ bool apply_routed_scaling_factor_on_output,
62
63
Params params) {
63
64
int tidx = threadIdx .x ;
64
65
int64_t thread_row =
@@ -248,6 +249,9 @@ __device__ void moe_fused_gate_impl(
248
249
for (int ii = 0 ; ii < topk; ++ii) {
249
250
int64_t const idx = topk * thread_row + ii;
250
251
output_ptr[idx] = output_ptr[idx] / output_sum;
252
+ if (apply_routed_scaling_factor_on_output) {
253
+ output_ptr[idx] *= routed_scaling_factor;
254
+ }
251
255
}
252
256
}
253
257
}
@@ -282,7 +286,8 @@ __global__ void moe_fused_gate_kernel(
282
286
int64_t topk_group,
283
287
int64_t topk,
284
288
int64_t num_fused_shared_experts,
285
- double routed_scaling_factor) {
289
+ double routed_scaling_factor,
290
+ bool apply_routed_scaling_factor_on_output) {
286
291
KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
287
292
moe_fused_gate_impl<T>(
288
293
input,
@@ -294,6 +299,7 @@ __global__ void moe_fused_gate_kernel(
294
299
topk,
295
300
num_fused_shared_experts,
296
301
routed_scaling_factor,
302
+ apply_routed_scaling_factor_on_output,
297
303
params);
298
304
}
299
305
@@ -314,7 +320,8 @@ __global__ void moe_fused_gate_kernel(
314
320
topk_group, \
315
321
topk, \
316
322
num_fused_shared_experts, \
317
- routed_scaling_factor); \
323
+ routed_scaling_factor, \
324
+ apply_routed_scaling_factor_on_output); \
318
325
dispatched = true ; \
319
326
} while (0 )
320
327
@@ -342,7 +349,8 @@ __global__ void moe_fused_gate_kernel_dynamic(
342
349
int64_t topk_group,
343
350
int64_t topk,
344
351
int64_t num_fused_shared_experts,
345
- double routed_scaling_factor) {
352
+ double routed_scaling_factor,
353
+ bool apply_routed_scaling_factor_on_output) {
346
354
KernelParamsDynamic params;
347
355
params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256
348
356
params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
@@ -361,6 +369,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
361
369
topk,
362
370
num_fused_shared_experts,
363
371
routed_scaling_factor,
372
+ apply_routed_scaling_factor_on_output,
364
373
params);
365
374
}
366
375
@@ -374,7 +383,8 @@ std::vector<at::Tensor> moe_fused_gate(
374
383
int64_t topk_group,
375
384
int64_t topk,
376
385
int64_t num_fused_shared_experts,
377
- double routed_scaling_factor) {
386
+ double routed_scaling_factor,
387
+ bool apply_routed_scaling_factor_on_output) {
378
388
int64_t num_rows = input.size (0 );
379
389
int32_t num_experts = input.size (1 );
380
390
auto options = torch::TensorOptions ().dtype (torch::kFloat32 ).device (torch::kCUDA );
@@ -473,7 +483,8 @@ std::vector<at::Tensor> moe_fused_gate(
473
483
topk_group,
474
484
topk,
475
485
num_fused_shared_experts,
476
- routed_scaling_factor);
486
+ routed_scaling_factor,
487
+ apply_routed_scaling_factor_on_output);
477
488
} else if (input.scalar_type () == at::kHalf ) {
478
489
moe_fused_gate_kernel_dynamic<float16_t ><<<num_blocks, block_dim, 0 , stream>>> (
479
490
input.data_ptr (),
@@ -486,7 +497,8 @@ std::vector<at::Tensor> moe_fused_gate(
486
497
topk_group,
487
498
topk,
488
499
num_fused_shared_experts,
489
- routed_scaling_factor);
500
+ routed_scaling_factor,
501
+ apply_routed_scaling_factor_on_output);
490
502
} else if (input.scalar_type () == at::kFloat ) {
491
503
moe_fused_gate_kernel_dynamic<float32_t ><<<num_blocks, block_dim, 0 , stream>>> (
492
504
input.data_ptr (),
@@ -499,7 +511,8 @@ std::vector<at::Tensor> moe_fused_gate(
499
511
topk_group,
500
512
topk,
501
513
num_fused_shared_experts,
502
- routed_scaling_factor);
514
+ routed_scaling_factor,
515
+ apply_routed_scaling_factor_on_output);
503
516
} else {
504
517
TORCH_CHECK (false , " Unsupported data type for moe_fused_gate" );
505
518
}
0 commit comments