@@ -59,7 +59,6 @@ __device__ void moe_fused_gate_impl(
59
59
int64_t topk,
60
60
int64_t num_fused_shared_experts,
61
61
double routed_scaling_factor,
62
- bool apply_routed_scaling_factor_on_output,
63
62
Params params) {
64
63
int tidx = threadIdx .x ;
65
64
int64_t thread_row =
@@ -249,9 +248,6 @@ __device__ void moe_fused_gate_impl(
249
248
for (int ii = 0 ; ii < topk; ++ii) {
250
249
int64_t const idx = topk * thread_row + ii;
251
250
output_ptr[idx] = output_ptr[idx] / output_sum;
252
- if (apply_routed_scaling_factor_on_output) {
253
- output_ptr[idx] *= routed_scaling_factor;
254
- }
255
251
}
256
252
}
257
253
}
@@ -286,8 +282,7 @@ __global__ void moe_fused_gate_kernel(
286
282
int64_t topk_group,
287
283
int64_t topk,
288
284
int64_t num_fused_shared_experts,
289
- double routed_scaling_factor,
290
- bool apply_routed_scaling_factor_on_output) {
285
+ double routed_scaling_factor) {
291
286
KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
292
287
moe_fused_gate_impl<T>(
293
288
input,
@@ -299,7 +294,6 @@ __global__ void moe_fused_gate_kernel(
299
294
topk,
300
295
num_fused_shared_experts,
301
296
routed_scaling_factor,
302
- apply_routed_scaling_factor_on_output,
303
297
params);
304
298
}
305
299
@@ -320,8 +314,7 @@ __global__ void moe_fused_gate_kernel(
320
314
topk_group, \
321
315
topk, \
322
316
num_fused_shared_experts, \
323
- routed_scaling_factor, \
324
- apply_routed_scaling_factor_on_output); \
317
+ routed_scaling_factor); \
325
318
dispatched = true ; \
326
319
} while (0 )
327
320
@@ -349,8 +342,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
349
342
int64_t topk_group,
350
343
int64_t topk,
351
344
int64_t num_fused_shared_experts,
352
- double routed_scaling_factor,
353
- bool apply_routed_scaling_factor_on_output) {
345
+ double routed_scaling_factor) {
354
346
KernelParamsDynamic params;
355
347
params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256
356
348
params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
@@ -369,7 +361,6 @@ __global__ void moe_fused_gate_kernel_dynamic(
369
361
topk,
370
362
num_fused_shared_experts,
371
363
routed_scaling_factor,
372
- apply_routed_scaling_factor_on_output,
373
364
params);
374
365
}
375
366
@@ -383,8 +374,7 @@ std::vector<at::Tensor> moe_fused_gate(
383
374
int64_t topk_group,
384
375
int64_t topk,
385
376
int64_t num_fused_shared_experts,
386
- double routed_scaling_factor,
387
- bool apply_routed_scaling_factor_on_output) {
377
+ double routed_scaling_factor) {
388
378
int64_t num_rows = input.size (0 );
389
379
int32_t num_experts = input.size (1 );
390
380
auto options = torch::TensorOptions ().dtype (torch::kFloat32 ).device (torch::kCUDA );
@@ -483,8 +473,7 @@ std::vector<at::Tensor> moe_fused_gate(
483
473
topk_group,
484
474
topk,
485
475
num_fused_shared_experts,
486
- routed_scaling_factor,
487
- apply_routed_scaling_factor_on_output);
476
+ routed_scaling_factor);
488
477
} else if (input.scalar_type () == at::kHalf ) {
489
478
moe_fused_gate_kernel_dynamic<float16_t ><<<num_blocks, block_dim, 0 , stream>>> (
490
479
input.data_ptr (),
@@ -497,8 +486,7 @@ std::vector<at::Tensor> moe_fused_gate(
497
486
topk_group,
498
487
topk,
499
488
num_fused_shared_experts,
500
- routed_scaling_factor,
501
- apply_routed_scaling_factor_on_output);
489
+ routed_scaling_factor);
502
490
} else if (input.scalar_type () == at::kFloat ) {
503
491
moe_fused_gate_kernel_dynamic<float32_t ><<<num_blocks, block_dim, 0 , stream>>> (
504
492
input.data_ptr (),
@@ -511,8 +499,7 @@ std::vector<at::Tensor> moe_fused_gate(
511
499
topk_group,
512
500
topk,
513
501
num_fused_shared_experts,
514
- routed_scaling_factor,
515
- apply_routed_scaling_factor_on_output);
502
+ routed_scaling_factor);
516
503
} else {
517
504
TORCH_CHECK (false , " Unsupported data type for moe_fused_gate" );
518
505
}
0 commit comments