
Commit 0227fcf

merrymercy and chenxijun1029 authored and committed
Move mem_fraction_static adjustment for multimodal models to server_args.py & Fix session control & Other cleanups (sgl-project#7748)
1 parent 67ce788 · commit 0227fcf

File tree: 16 files changed, +338 −136 lines

python/sglang/srt/hf_transformers_utils.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -42,7 +42,7 @@
 )
 from sglang.srt.configs.internvl import InternVLChatConfig
 from sglang.srt.connector import create_remote_connector
-from sglang.srt.utils import is_remote_url
+from sglang.srt.utils import is_remote_url, lru_cache_frozenset

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
@@ -103,6 +103,7 @@ def get_hf_text_config(config: PretrainedConfig):
     return config


+@lru_cache_frozenset(maxsize=32)
 def get_config(
     model: str,
     trust_remote_code: bool,
```
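The decorator memoizes `get_config`, so repeated lookups for the same model avoid re-reading the Hugging Face config. A plain `functools.lru_cache` would not work here because some arguments (e.g. dict-valued overrides) are unhashable. Below is a minimal sketch of such a decorator, assuming it works by recursively freezing unhashable arguments into hashable cache keys; the real implementation in `sglang.srt.utils` may differ:

```python
import functools
from collections import OrderedDict


def lru_cache_frozenset(maxsize=32):
    """Sketch of an LRU cache whose keys tolerate unhashable arguments
    by recursively freezing dicts/lists/sets into hashable forms."""

    def _freeze(value):
        if isinstance(value, dict):
            return frozenset((k, _freeze(v)) for k, v in value.items())
        if isinstance(value, (list, tuple, set)):
            return tuple(_freeze(v) for v in value)
        return value

    def decorator(func):
        cache = OrderedDict()

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            key = (_freeze(args), _freeze(sorted(kwargs.items())))
            if key in cache:
                cache.move_to_end(key)  # mark as most recently used
                return cache[key]
            result = func(*args, **kwargs)
            cache[key] = result
            if len(cache) > maxsize:
                cache.popitem(last=False)  # evict least recently used
            return result

        return wrapper

    return decorator
```

Usage matches the diff: placing `@lru_cache_frozenset(maxsize=32)` above `def get_config(...)` caches up to 32 distinct config lookups.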

python/sglang/srt/layers/activation.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -46,11 +46,11 @@
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

-logger = logging.getLogger(__name__)
-
 if is_npu():
     import torch_npu

+logger = logging.getLogger(__name__)
+

 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
```

python/sglang/srt/managers/io_struct.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -39,6 +39,7 @@ class SessionParams:
     rid: Optional[str] = None
     offset: Optional[int] = None
     replace: Optional[bool] = None
+    drop_previous_output: Optional[bool] = None


 AudioDataItem = Union[str, Dict]
```
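The new `drop_previous_output` flag lets a session continuation reuse the previous turn's prompt while discarding the tokens it generated. A hypothetical request payload using the flag (the endpoint and surrounding fields follow the usual `/generate` session API; the exact shape may vary):

```python
import requests

# Hypothetical continuation of session "sess0" from request "req0":
# keep the original prompt, drop the tokens "req0" generated.
payload = {
    "text": "Now answer the same question in French.",
    "session_params": {
        "id": "sess0",
        "rid": "req0",
        "drop_previous_output": True,  # field added in this commit
    },
    "sampling_params": {"max_new_tokens": 64},
}
print(requests.post("http://localhost:30000/generate", json=payload).json())
```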

python/sglang/srt/managers/schedule_batch.py

Lines changed: 18 additions & 10 deletions

```diff
@@ -203,7 +203,7 @@ class MultimodalDataItem:

     # the real data, pixel_values or audio_features
     # data: Union[List[torch.Tensor], List[np.ndarray]]
-    pixel_values: Union[torch.Tensor, np.ndarray] = None
+    pixel_values: Union[torch.Tensor, np.ndarray, "PIL.Image"] = None
     audio_features: Union[torch.Tensor, np.ndarray] = None
     audio_feature_lens: Optional[List[torch.Tensor]] = None
     audio_offsets: Optional[List[Tuple[int, int]]] = None
@@ -244,15 +244,16 @@ def set_pad_value(self):
         """
         from sglang.srt.managers.mm_utils import hash_feature

-        if self.precomputed_features is not None:
-            self.hash = hash_feature(self.precomputed_features)
-        elif self.is_audio():
-            if self.audio_features is not None:
-                self.hash = hash_feature(self.audio_features)
-            elif self.input_features is not None:
-                self.hash = hash_feature(self.input_features)
-        else:
-            self.hash = hash_feature(self.pixel_values)
+        if self.hash is None:
+            if self.precomputed_features is not None:
+                self.hash = hash_feature(self.precomputed_features)
+            elif self.is_audio():
+                if self.audio_features is not None:
+                    self.hash = hash_feature(self.audio_features)
+                elif self.input_features is not None:
+                    self.hash = hash_feature(self.input_features)
+            else:
+                self.hash = hash_feature(self.pixel_values)

         assert self.hash is not None
         self.pad_value = self.hash % (1 << 30)
@@ -295,6 +296,13 @@ def from_dict(obj: dict):
         ret.validate()
         return ret

+    def merge(self, other):
+        self.pixel_values += other.pixel_values
+        self.image_sizes += other.image_sizes
+        self.image_offsets += other.image_offsets
+        self.hash = hash((self.hash, other.hash))
+        self.set_pad_value()
+

 @dataclasses.dataclass
 class MultimodalInputs:
```
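The `if self.hash is None` guard makes `set_pad_value` idempotent, which the new `merge` relies on: `merge` first derives a combined hash from the two items' hashes and then calls `set_pad_value`, which must compute `pad_value` from that combined hash rather than re-hash the concatenated pixel data. A toy illustration with made-up hash values:

```python
# Made-up values; mirrors the arithmetic in merge() and set_pad_value().
hash_a, hash_b = 1234567, 7654321      # hashes of the two merged items
merged_hash = hash((hash_a, hash_b))   # merge(): combine, don't re-hash data
pad_value = merged_hash % (1 << 30)    # set_pad_value(): derive pad token id
print(merged_hash, pad_value)
```

Note that `merge` also assumes list-typed `pixel_values`, `image_sizes`, and `image_offsets`, since it concatenates them with `+=`.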

python/sglang/srt/managers/scheduler.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1100,7 +1100,7 @@ def handle_generate_request(
             recv_req.session_params is not None
             and recv_req.session_params.id is not None
         ):
-            req.finished_reason = FINISH_ABORT(
+            req.set_finish_with_abort(
                 f"Invalid request: session id {recv_req.session_params.id} does not exist"
             )
             self._add_request_to_queue(req)
```
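Routing the abort through `req.set_finish_with_abort(...)` instead of assigning `FINISH_ABORT` directly keeps the finish bookkeeping in one place (the same helper is used in session_controller.py below). A rough sketch of the helper's core effect; this body is an assumption, and the real method on `Req` may reset additional request state:

```python
def set_finish_with_abort(self, error_msg: str):
    # Hypothetical sketch: record the abort reason so downstream
    # scheduling treats the request as finished.
    self.finished_reason = FINISH_ABORT(message=error_msg)
```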

python/sglang/srt/managers/session_controller.py

Lines changed: 12 additions & 3 deletions

```diff
@@ -54,7 +54,7 @@ def _str_helper(self, prefix=""):
            prefix += " -- " + self.childs[0].req.rid
            ret = self.childs[0]._str_helper(prefix)
            for child in self.childs[1:]:
-                prefix = " " * len(origin_prefix) + " \- " + child.req.rid
+                prefix = " " * len(origin_prefix) + r" \- " + child.req.rid
                ret += child._str_helper(prefix)
        return ret

@@ -106,14 +106,22 @@ def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
                last_req.origin_input_ids
                + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
            )
+
+            if session_params.drop_previous_output:
+                input_ids = last_req.origin_input_ids[:]
+
            if session_params.offset and session_params.offset != 0:
                input_ids = input_ids[: session_params.offset] + req.input_ids
            else:
                input_ids += req.input_ids
+
            input_ids_unpadded = (
                last_req.origin_input_ids_unpadded
                + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
            )
+            if session_params.drop_previous_output:
+                input_ids_unpadded = last_req.origin_input_ids_unpadded[:]
+
            if session_params.offset and session_params.offset != 0:
                input_ids_unpadded = (
                    input_ids_unpadded[: session_params.offset] + req.input_ids
@@ -138,10 +146,11 @@ def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
                token_ids_logprob=req.token_ids_logprob,
            )
            if last_req is not None:
-                new_req.multimodal_inputs = last_req.mm_inputs
+                new_req.multimodal_inputs = last_req.multimodal_inputs
            new_req.tokenizer = tokenizer
+
        if abort:
-            new_req.to_abort = True
+            new_req.set_finish_with_abort("Invalid request session id")
        else:
            new_req_node = SessionReqNode(new_req, last_req_node)
            self.req_nodes[req.rid] = new_req_node
```

(The first hunk adds the `r` prefix so `" \- "` is no longer an invalid escape sequence; the last hunk fixes a wrong attribute name, `mm_inputs` → `multimodal_inputs`.)
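The effect of the new `drop_previous_output` branch on session input construction can be seen in isolation. A self-contained re-creation of the assembly logic above, with made-up token ids:

```python
# Simplified re-creation of the create_req token assembly, made-up ids.
origin_input_ids = [1, 2, 3]   # last request's prompt
output_ids = [10, 11]          # last request's generated tokens
new_input_ids = [4, 5]         # the incoming request's tokens


def build_input_ids(drop_previous_output: bool, offset=None):
    input_ids = origin_input_ids + output_ids
    if drop_previous_output:
        input_ids = origin_input_ids[:]  # discard the previous output
    if offset:
        input_ids = input_ids[:offset] + new_input_ids
    else:
        input_ids = input_ids + new_input_ids
    return input_ids


print(build_input_ids(False))  # [1, 2, 3, 10, 11, 4, 5]
print(build_input_ids(True))   # [1, 2, 3, 4, 5]
```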

python/sglang/srt/managers/tokenizer_manager.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -1148,6 +1148,7 @@ def get_log_request_metadata(self):
                [
                    "text",
                    "output_ids",
+                    "embedding",
                ]
            )
        elif self.log_requests_level == 1:
@@ -1166,6 +1167,7 @@ def get_log_request_metadata(self):
                [
                    "text",
                    "output_ids",
+                    "embedding",
                ]
            )
        elif self.log_requests_level == 2:
```

python/sglang/srt/mem_cache/multimodal_cache.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -24,6 +24,9 @@ def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
        self.current_size += data_size
        return True

+    def has(self, mm_hash: int) -> bool:
+        return mm_hash in self.mm_cache
+
    def get(self, mm_hash: int) -> torch.Tensor:
        return self.mm_cache.get(mm_hash)
```
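The new `has` method gives callers a containment check without fetching the tensor. A usage sketch; `encode_mm_item` and the surrounding wiring are hypothetical:

```python
# Hypothetical caller: reuse a cached multimodal embedding when present.
if embedding_cache.has(item.hash):
    embedding = embedding_cache.get(item.hash)
else:
    embedding = encode_mm_item(item)           # hypothetical encoder call
    embedding_cache.put(item.hash, embedding)  # cache for later requests
```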

python/sglang/srt/model_executor/model_runner.py

Lines changed: 0 additions & 5 deletions

```diff
@@ -451,11 +451,6 @@ def model_specific_adjustment(self):
            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

        if self.is_multimodal:
-            self.mem_fraction_static *= 0.90
-            logger.info(
-                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                f"because this is a multimodal model."
-            )
            if not self.is_multimodal_chunked_prefill_supported:
                server_args.chunked_prefill_size = -1
                logger.info(
```
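Per the commit title, this 0.90 multiplier moves into `server_args.py` (that file's diff is not shown in this excerpt). A sketch of what the relocated adjustment could look like there; the placement and surrounding names are assumptions:

```python
# Hypothetical location: where ServerArgs finalizes model-dependent defaults.
if model_config.is_multimodal:
    self.mem_fraction_static *= 0.90
    logger.info(
        f"Automatically reduce --mem-fraction-static to "
        f"{self.mem_fraction_static:.3f} because this is a multimodal model."
    )
```

Adjusting the fraction at argument-resolution time means the value the model runner later sees is already final, instead of being silently rewritten mid-initialization.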

python/sglang/srt/models/qwen3.py

Lines changed: 0 additions & 2 deletions

```diff
@@ -11,8 +11,6 @@
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
-    split_tensor_along_last_dim,
-    tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
```
