     topk_ids_logical_to_physical,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_compiler_backend,
+    is_cpu,
+    is_cuda,
+    is_hip,
+)

 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 if _is_cuda:
     from sgl_kernel import moe_fused_gate

 if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax


-def fused_topk_native(
+def fused_topk_torch_native(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -61,6 +69,20 @@ def fused_topk_native(
     return topk_weights, topk_ids


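+# CPU path for the torch-native topk: forwards to sgl_kernel's
+# topk_softmax_cpu op. It keeps the same signature as
+# fused_topk_torch_native so the alias binding at the bottom of this
+# file can swap implementations transparently.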
+def fused_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    return torch.ops.sgl_kernel.topk_softmax_cpu(
+        hidden_states=hidden_states,
+        gating_output=gating_output,
+        topk=topk,
+        renormalize=renormalize,
+    )
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -115,7 +137,7 @@ def _fused_topk_postprocess(

 # This is used by the Deepseek V2/V3/R1 series models
 @torch.compile(dynamic=True, backend=get_compiler_backend())
-def grouped_topk(
+def grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -171,6 +193,32 @@ def grouped_topk(
     return topk_weights, topk_ids


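+# CPU counterpart of grouped_topk_gpu; expert-location dispatch is not
+# supported on this path, hence the assert below. Arguments are passed
+# positionally to the C++ op.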
+def grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
 def biased_grouped_topk_impl(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -258,7 +306,7 @@ def _biased_grouped_topk_postprocess(
     return topk_ids


-def biased_grouped_topk(
+def biased_grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     correction_bias: torch.Tensor,
@@ -322,6 +370,45 @@ def biased_grouped_topk(
     )


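+# CPU counterpart of biased_grouped_topk_gpu. `compiled` is accepted for
+# signature parity with the GPU path but is not forwarded to the CPU op.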
+def biased_grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    compiled: bool = True,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        correction_bias,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
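+# Bind the public names once at import time: prefer the AMX-backed CPU
+# kernels when running on a CPU that supports them, otherwise keep the
+# GPU/torch-native implementations.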
+if _is_cpu and _is_cpu_amx_available:
+    biased_grouped_topk = biased_grouped_topk_cpu
+    grouped_topk = grouped_topk_cpu
+    fused_topk_native = fused_topk_cpu
+else:
+    biased_grouped_topk = biased_grouped_topk_gpu
+    grouped_topk = grouped_topk_gpu
+    fused_topk_native = fused_topk_torch_native
+
+
 def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,