
Commit 537b451

address reviews, minor fix
1 parent e9f72af commit 537b451

File tree: 2 files changed, +16 -14 lines

python/sglang/srt/layers/moe/fused_moe_triton/layer.py

Lines changed: 4 additions & 8 deletions
@@ -649,16 +649,12 @@ def _load_w2(
                 shard_size,
                 not self.use_presharded_weights,
             )
-            if not self.use_presharded_weights:
-                if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
-                    raise ValueError(
-                        f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
-                    )
-                loaded_weight = loaded_weight.narrow(
-                    shard_dim, shard_size * tp_rank, shard_size
-                )
         else:
             if not self.use_presharded_weights:
+                if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
+                    raise ValueError(
+                        f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
+                    )
                 loaded_weight = loaded_weight.narrow(
                     shard_dim, shard_size * tp_rank, shard_size
                 )
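For context, the else branch above slices one tensor-parallel rank's shard out of the full weight with torch.Tensor.narrow, and the relocated check now rejects out-of-range shards only on that non-presharded path. A minimal, self-contained sketch of the slicing logic, assuming a standalone helper (shard_for_rank is illustrative, not part of the layer):

import torch


def shard_for_rank(
    loaded_weight: torch.Tensor, shard_dim: int, shard_size: int, tp_rank: int
) -> torch.Tensor:
    """Slice one rank's shard out of a full (non-presharded) weight tensor."""
    start = shard_size * tp_rank
    # Same bounds check as in the diff: fail with a descriptive message
    # instead of letting narrow() raise a more opaque error.
    if start + shard_size > loaded_weight.shape[shard_dim]:
        raise ValueError(
            f"Shard size {shard_size} at rank {tp_rank} exceeds "
            f"loaded_weight dimension {loaded_weight.shape[shard_dim]}"
        )
    # narrow() returns a zero-copy view of shard_size elements along shard_dim.
    return loaded_weight.narrow(shard_dim, start, shard_size)


# Example: a [4, 8] weight split across 4 ranks along dim 0.
w = torch.arange(32, dtype=torch.float32).reshape(4, 8)
print(shard_for_rank(w, shard_dim=0, shard_size=1, tp_rank=2).shape)  # torch.Size([1, 8])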

python/sglang/srt/models/mllama4.py

Lines changed: 12 additions & 6 deletions
@@ -1,4 +1,5 @@
 import json as json_lib
+import logging
 import os
 from collections.abc import Iterable
 from typing import List, Optional, Set, Tuple
@@ -27,6 +28,8 @@
 )
 from sglang.srt.utils import add_prefix
 
+logger = logging.getLogger(__name__)
+
 
 class Llama4ForConditionalGeneration(nn.Module):
     packed_modules_mapping = {
@@ -46,6 +49,11 @@ def __init__(
 
         # Check if this is a text-only model (modelopt fp8 llama4 has no vision components)
        self.has_vision = self._has_vision_weights(config)
+        if not self.has_vision:
+            logger.warning(
+                "No vision weights found in checkpoint. Model will run in text-only mode. "
+                "Multimodal capabilities (image processing) will be unavailable."
+            )
 
         if self.has_vision:
             self.vision_model = Llama4VisionModel(config.vision_config)
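The new warning is emitted once at construction time through the module-level logger added above. A tiny, self-contained illustration of that pattern (TextOnlyAwareModel is a hypothetical stand-in, not the real Llama4ForConditionalGeneration):

import logging

logger = logging.getLogger(__name__)


class TextOnlyAwareModel:
    """Toy stand-in showing the one-time text-only warning."""

    def __init__(self, has_vision: bool):
        self.has_vision = has_vision
        if not self.has_vision:
            logger.warning(
                "No vision weights found in checkpoint. "
                "Model will run in text-only mode."
            )


logging.basicConfig(level=logging.WARNING)
TextOnlyAwareModel(has_vision=False)  # logs the warning once at construction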
@@ -225,12 +233,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
         )
 
         for name, loaded_weight in weights:
-            if not self._should_load_weight(name):
+            if self._should_skip_weight(name):
                 continue
 
             name = self._transform_weight_name(name)
-            if name is None:
-                continue
 
             if "vision" not in name:
                 name, loaded_weight = self.permute_qk_weight_for_rotary(
@@ -252,9 +258,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
 
             self._handle_default_weight(name, loaded_weight, params_dict)
 
-    def _should_load_weight(self, name: str) -> bool:
-        """Check if we should load this weight."""
-        return not ("vision" in name and not self.has_vision)
+    def _should_skip_weight(self, name: str) -> bool:
+        """Check if we should skip loading this weight."""
+        return "vision" in name and not self.has_vision
 
     def _transform_weight_name(self, name: str) -> str:
         """Transform weight name by adding language_model prefix if needed."""
