Commit 1bca79a

jianan-gu authored and yanbing-j committed
adding origin rope and topk fusion frontend (#1)
1 parent: d12e2eb

File tree: 2 files changed (+45, -6 lines)

python/sglang/srt/layers/moe/topk.py

Lines changed: 15 additions & 6 deletions
@@ -380,12 +380,21 @@ def select_experts(
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in fused_topk_native"
         assert expert_location_dispatch_info is None
-        topk_weights, topk_ids = fused_topk_native(
-            hidden_states=hidden_states,
-            gating_output=router_logits,
-            topk=top_k,
-            renormalize=renormalize,
-        )
+        device = hidden_states.device
+        if device == torch.device("cpu") and _is_cpu_amx:
+            topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+            )
+        else:
+            topk_weights, topk_ids = fused_topk_native(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+            )
     elif custom_routing_function is None:
         assert (
             num_token_non_padded is None
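
For reference, here is a hedged plain-PyTorch sketch of the routing math both branches are expected to agree on: softmax over the router logits, take the k largest probabilities per token, optionally renormalize them. This is not the commit's code; it only illustrates the semantics that fused_topk_native and a fused topk_softmax_cpu kernel would need to match.

import torch

# Hypothetical reference sketch (not part of the commit): top-k softmax routing.
def topk_softmax_reference(gating_output: torch.Tensor, topk: int, renormalize: bool):
    probs = torch.softmax(gating_output.float(), dim=-1)   # [num_tokens, num_experts]
    topk_weights, topk_ids = probs.topk(topk, dim=-1)       # top-k experts per token
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

# Example: 4 tokens routed over 8 experts, keeping 2 experts per token.
weights, ids = topk_softmax_reference(torch.randn(4, 8), topk=2, renormalize=True)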

python/sglang/srt/layers/rotary_embedding.py

Lines changed: 30 additions & 0 deletions
@@ -118,6 +118,16 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
+    def forward(self, *args, **kwargs):
+        if torch.compiler.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda:
+            return self.forward_cuda(*args, **kwargs)
+        elif _is_cpu_amx:
+            return self.forward_cpu(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
     def forward_native(
         self,
         positions: torch.Tensor,
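
A note on the dispatch order in the new forward(): torch.compiler.is_compiling() is checked first so that a torch.compile-traced graph always captures the pure-PyTorch forward_native path rather than a backend-specific custom op. The tiny snippet below (hypothetical, not part of the commit) only illustrates that flag.

import torch

# Hypothetical illustration: the flag is False in eager mode and True while
# torch.compile/Dynamo is tracing, which is what routes compiled runs to forward_native.
def pick_path() -> str:
    return "forward_native" if torch.compiler.is_compiling() else "backend-specific"

print(pick_path())  # an eager call prints "backend-specific"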
@@ -148,6 +158,26 @@ def forward_native(
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
+    def forward_cpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        positions = torch.add(positions, offsets) if offsets is not None else positions
+        if positions.device == torch.device("cpu") and _is_cpu_amx:
+            return torch.ops.sgl_kernel.rotary_embedding_cpu(
+                positions,
+                query,
+                key,
+                self.head_size,
+                self.cos_sin_cache,
+                self.is_neox_style,
+            )
+        else:
+            return self.forward_native(positions, query, key, offsets)
+
     def forward_cuda(
         self,
         positions: torch.Tensor,
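
As with the top-k path, the CPU rotary kernel is expected to be numerically equivalent to forward_native. Below is a hedged plain-PyTorch sketch of NeoX-style RoPE using the cos_sin_cache layout from _compute_cos_sin_cache (cat((cos, sin), dim=-1)); it is not the commit's code, only an illustration of the math rotary_embedding_cpu would have to reproduce, shown for the query tensor.

import torch

# Hypothetical NeoX-style RoPE sketch (not part of the commit). Assumes q has shape
# [num_tokens, num_heads * head_size] and cos_sin_cache has shape
# [max_position, head_size], laid out as cat((cos, sin), dim=-1).
def rope_neox_reference(positions, q, cos_sin_cache, head_size):
    num_tokens = positions.shape[0]
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)  # each [num_tokens, head_size // 2]
    cos, sin = cos[:, None, :], sin[:, None, :]           # broadcast over heads
    q = q.view(num_tokens, -1, head_size)
    q1, q2 = q.chunk(2, dim=-1)                           # NeoX: rotate first/second halves
    q_rot = torch.cat((q1 * cos - q2 * sin, q2 * cos + q1 * sin), dim=-1)
    return q_rot.reshape(num_tokens, -1)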

0 commit comments
