
Commit 22b8fe3

yanbing-j, chunyuan-w, jianan-gu, and sdp authored and committed

Update python API of activation, topk, norm and rope and remove vllm dependency (sgl-project#6614)

Co-authored-by: Wu, Chunyuan <[email protected]>
Co-authored-by: jianan-gu <[email protected]>
Co-authored-by: sdp <[email protected]>

1 parent fb3750d · commit 22b8fe3

23 files changed: +270 −56 lines

docker/Dockerfile.xeon

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ RUN git clone https://github.com/sgl-project/sglang.git && \
     cp pyproject_cpu.toml pyproject.toml && \
     pip install -v .

+ENV SGLANG_USE_CPU_ENGINE=1
 ENV LD_PRELOAD=/sgl-workspace/miniforge3/lib/libiomp5.so:/sgl-workspace/miniforge3/lib/libtcmalloc.so:/sgl-workspace/miniforge3/lib/libtbbmalloc.so.2

 WORKDIR /sgl-workspace/sglang
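The image now opts into the CPU engine via an environment variable. The code that consumes SGLANG_USE_CPU_ENGINE is not part of this diff; the sketch below only illustrates, under that assumption, how such a flag is typically read with a boolean-env helper like the get_bool_env_var used elsewhere in this commit. The local _bool_env function is a stand-in, not sglang's actual implementation.

import os


def _bool_env(name: str, default: str = "false") -> bool:
    # Local stand-in for a helper like sglang.srt.utils.get_bool_env_var;
    # exact accepted values may differ in the real code.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")


if _bool_env("SGLANG_USE_CPU_ENGINE"):
    print("CPU engine requested; CustomOp will prefer forward_cpu on AMX-capable hosts")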

python/sglang/srt/custom_op.py

Lines changed: 5 additions & 1 deletion
@@ -1,9 +1,11 @@
 from torch import nn

-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip

 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_cpu = is_cpu()
+_is_cpu_amx_available = cpu_has_amx_support()


 class CustomOp(nn.Module):
@@ -75,5 +77,7 @@ def dispatch_forward(self):
             return self.forward_cuda
         elif _is_hip:
             return self.forward_hip
+        elif _is_cpu and _is_cpu_amx_available:
+            return self.forward_cpu
         else:
             return self.forward_native
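For context, a minimal standalone sketch of the dispatch pattern this hunk extends (the class name and constructor are made up for illustration; the real CustomOp lives in sglang.srt.custom_op and has more branches): only CPUs with AMX support take the sgl-kernel forward_cpu path, while plain CPUs keep falling back to forward_native.

from torch import nn


class TinyCustomOp(nn.Module):
    """Illustrative re-creation of the dispatch logic, not the real class."""

    def __init__(self, is_cuda: bool, is_hip: bool, is_cpu: bool, amx: bool):
        super().__init__()
        if is_cuda:
            self._forward_method = self.forward_cuda
        elif is_hip:
            self._forward_method = self.forward_hip
        elif is_cpu and amx:
            # Branch added by this commit: AMX-capable CPUs use the CPU kernels.
            self._forward_method = self.forward_cpu
        else:
            self._forward_method = self.forward_native

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, x):
        return x

    # Real subclasses override these; aliased here only to keep the sketch runnable.
    forward_cuda = forward_hip = forward_cpu = forward_native


op = TinyCustomOp(is_cuda=False, is_hip=False, is_cpu=True, amx=True)
assert op._forward_method == op.forward_cpu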

python/sglang/srt/layers/activation.py

Lines changed: 20 additions & 3 deletions
@@ -29,11 +29,19 @@
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    is_cpu,
+    is_cuda,
+    is_npu,
+    set_weight_attrs,
+)
 from sglang.utils import resolve_obj_by_qualname

 _is_cuda = is_cuda()
 _is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
@@ -53,6 +61,15 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         silu_and_mul(x, out)
         return out

+    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
+        if _is_cpu_amx_available:
+            d = x.shape[-1] // 2
+            output_shape = x.shape[:-1] + (d,)
+            out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
+            return out
+        else:
+            return self.forward_native(x)
+

 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
@@ -185,8 +202,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
         return nn.Identity()


-if not _is_cuda and not _is_npu:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     logger.info(
-        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
+        "sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
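The new forward_cpu path calls torch.ops.sgl_kernel.silu_and_mul_cpu when AMX is available and otherwise defers to forward_native. As a reference point, the torch-native formulation below restates what SiluAndMul computes (the standard gate-and-up split; written here for comparison, not copied from this diff), so the CPU kernel's output can be checked against it.

import torch
import torch.nn.functional as F


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half: silu(gate) * up.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


x = torch.randn(4, 256)
ref = silu_and_mul_reference(x)  # shape (4, 128)
# On an AMX-capable host with sgl-kernel's CPU ops registered, one would expect:
#   out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
#   torch.testing.assert_close(out, ref)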

python/sglang/srt/layers/layernorm.py

Lines changed: 28 additions & 2 deletions
@@ -20,12 +20,21 @@
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip, is_npu
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+)

 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_npu = is_npu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 if _is_cuda:
     from sgl_kernel import (
@@ -122,6 +131,23 @@ def forward_native(
         else:
             return x, residual

+    def forward_cpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if _is_cpu_amx_available:
+            if residual is not None:
+                torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
+                    x, residual, self.weight.data, self.variance_epsilon
+                )
+                return x, residual
+            return torch.ops.sgl_kernel.rmsnorm_cpu(
+                x, self.weight.data, self.variance_epsilon
+            )
+        else:
+            return self.forward_native(x, residual)
+

 class GemmaRMSNorm(CustomOp):
     def __init__(
@@ -188,7 +214,7 @@ def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"


-if not (_is_cuda or _is_hip or _is_npu):
+if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     logger.info(
         "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
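forward_cpu routes to two kernels: fused_add_rmsnorm_cpu updates x and residual in place when a residual is passed, and rmsnorm_cpu handles the plain case. The torch-native reference below sketches the expected math (standard RMSNorm with a float32 reduction, stated here as an assumption for comparison rather than taken from this diff).

import torch


def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # x * rsqrt(mean(x^2) + eps) * weight, with the reduction done in float32.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight


def fused_add_rmsnorm_reference(x, residual, weight, eps):
    # The fused CPU kernel folds the residual add into the norm and mutates its
    # inputs in place; this out-of-place version shows the same resulting values.
    residual = residual + x
    return rmsnorm_reference(residual, weight, eps), residual


h = torch.randn(2, 64)
w = torch.ones(64)
y = rmsnorm_reference(h, w, 1e-6)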

python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

Lines changed: 6 additions & 0 deletions
@@ -25,9 +25,11 @@
     sglang_per_token_group_quant_int8,
 )
 from sglang.srt.utils import (
+    cpu_has_amx_support,
     direct_register_custom_op,
     get_bool_env_var,
     get_device_name,
+    is_cpu,
     is_cuda,
     is_hip,
     log_info_on_rank0,
@@ -36,9 +38,13 @@

 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
+elif _is_cpu and _is_cpu_amx_available:
+    pass
 else:
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant

python/sglang/srt/layers/moe/fused_moe_triton/layer.py

Lines changed: 5 additions & 1 deletion
@@ -241,7 +241,11 @@ def forward_cpu(
         num_fused_shared_experts: int = 0,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         return moe_forward_native(
             layer,
@@ -260,7 +264,7 @@ def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")

-    forward_native = forward_cuda
+    forward_native = forward_cpu


 class FusedMoE(torch.nn.Module):
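The added keyword arguments (activation, apply_router_weight_on_input, no_combine, routed_scaling_factor) appear to keep forward_cpu's signature in line with the other backends so callers can pass the same options uniformly. The alias flip matters on hosts where CustomOp dispatch falls through to forward_native: previously that name pointed at forward_cuda, now it points at the torch-native MoE path. A toy illustration (class and return strings are made up for the sketch):

class _MoEDispatchSketch:
    def forward_cuda(self):
        return "fused GPU experts (needs sgl-kernel / CUDA)"

    def forward_cpu(self):
        return "moe_forward_native (pure torch, works on CPU)"

    # After this commit; previously `forward_native = forward_cuda`.
    forward_native = forward_cpu


print(_MoEDispatchSketch().forward_native())  # -> the pure-torch path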

python/sglang/srt/layers/moe/topk.py

Lines changed: 91 additions & 4 deletions
@@ -28,10 +28,18 @@
     topk_ids_logical_to_physical,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_compiler_backend,
+    is_cpu,
+    is_cuda,
+    is_hip,
+)

 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 if _is_cuda:
     from sgl_kernel import moe_fused_gate
@@ -40,7 +48,7 @@
     from sgl_kernel import topk_softmax


-def fused_topk_native(
+def fused_topk_torch_native(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -61,6 +69,20 @@ def fused_topk_native(
     return topk_weights, topk_ids


+def fused_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    return torch.ops.sgl_kernel.topk_softmax_cpu(
+        hidden_states=hidden_states,
+        gating_output=gating_output,
+        topk=topk,
+        renormalize=renormalize,
+    )
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -115,7 +137,7 @@ def _fused_topk_postprocess(

 # This is used by the Deepseek V2/V3/R1 series models
 @torch.compile(dynamic=True, backend=get_compiler_backend())
-def grouped_topk(
+def grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -171,6 +193,32 @@ def grouped_topk(
     return topk_weights, topk_ids


+def grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
 def biased_grouped_topk_impl(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -258,7 +306,7 @@ def _biased_grouped_topk_postprocess(
     return topk_ids


-def biased_grouped_topk(
+def biased_grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     correction_bias: torch.Tensor,
@@ -322,6 +370,45 @@ def biased_grouped_topk(
     )


+def biased_grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    compiled: bool = True,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        correction_bias,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
+if _is_cpu and _is_cpu_amx_available:
+    biased_grouped_topk = biased_grouped_topk_cpu
+    grouped_topk = grouped_topk_cpu
+    fused_topk_native = fused_topk_cpu
+else:
+    biased_grouped_topk = biased_grouped_topk_gpu
+    grouped_topk = grouped_topk_gpu
+    fused_topk_native = fused_topk_torch_native
+
+
 def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
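The renames plus the module-level assignment at the end mean callers keep importing grouped_topk, biased_grouped_topk and fused_topk_native unchanged; the concrete implementation is chosen once at import time. A condensed sketch of the pattern (bodies and flag values are illustrative only):

def grouped_topk_gpu(*args, **kwargs):
    ...  # torch.compile'd implementation shown in the diff above


def grouped_topk_cpu(*args, **kwargs):
    ...  # thin wrapper over torch.ops.sgl_kernel.grouped_topk_cpu


_is_cpu, _is_cpu_amx_available = True, True  # illustrative values
grouped_topk = grouped_topk_cpu if (_is_cpu and _is_cpu_amx_available) else grouped_topk_gpu

Note that both CPU wrappers assert expert_location_dispatch_info is None, so expert-location dispatch is only handled on the GPU path.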

python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 5 additions & 2 deletions
@@ -14,15 +14,18 @@
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
+    cpu_has_amx_support,
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs
+from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs

 _is_cuda = is_cuda()
 _is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

-if not _is_cuda and not _is_npu:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant

python/sglang/srt/layers/quantization/fp8.py

Lines changed: 5 additions & 1 deletion
@@ -64,7 +64,9 @@ def dummy_func(*args, **kwargs):
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.utils import (
+    cpu_has_amx_support,
     get_bool_env_var,
+    is_cpu,
     is_cuda,
     is_hip,
     is_npu,
@@ -76,6 +78,8 @@ def dummy_func(*args, **kwargs):
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 _is_fp8_fnuz = is_fp8_fnuz()

@@ -88,7 +92,7 @@ def dummy_func(*args, **kwargs):
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight

-if not _is_cuda and not _is_npu:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm._custom_ops import scaled_fp8_quant

python/sglang/srt/layers/quantization/utils.py

Lines changed: 4 additions & 2 deletions
@@ -6,12 +6,14 @@
 import torch

 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.utils import is_cuda, is_npu
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu

 _is_cuda = is_cuda()
 _is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

-if not _is_cuda and not _is_npu:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm._custom_ops import scaled_fp8_quant
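The same platform gate now appears in activation.py, fp8.py, compressed_tensors_moe.py and here. A hypothetical helper (not part of this commit; the function name is invented) that states the condition once:

from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu


def needs_vllm_fallback() -> bool:
    """True when sgl-kernel has no implementation for the current platform."""
    return not (is_cuda() or is_npu() or (is_cpu() and cpu_has_amx_support()))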
