Commit 6e6175b

Merge branch 'main' into zhiyu/llama4-fp8-modelopt

2 parents: 537b451 + 8364608

45 files changed: +4415, -162 lines

docs/references/production_metrics.md

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@
 
 SGLang exposes the following metrics via Prometheus. The metrics are namespaced by `$name` (the model name).
 
-An example of the monitoring dashboard is available in [examples/monitoring/grafana.json](../examples/monitoring/grafana/dashboards/json/sglang-dashboard.json).
+An example of the monitoring dashboard is available in [examples/monitoring/grafana.json](https://github.com/sgl-project/sglang/blob/main/examples/monitoring/grafana/dashboards/json/sglang-dashboard.json).
 
 Here is an example of the metrics:

python/sglang/srt/configs/model_config.py

Lines changed: 1 addition & 0 deletions

@@ -593,6 +593,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
     "Mistral3ForConditionalGeneration",
     "MultiModalityCausalLM",
     "MllamaForConditionalGeneration",
+    "Qwen2AudioForConditionalGeneration",
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
     "KimiVLForConditionalGeneration",

python/sglang/srt/conversation.py

Lines changed: 34 additions & 0 deletions

@@ -59,6 +59,7 @@ class SeparatorStyle(IntEnum):
     METAMATH = auto()
     DeepSeekVL2 = auto()
     QWEN2_VL_EMBED = auto()
+    QWEN2_AUDIO = auto()
     GEMMA3 = auto()
     MPT = auto()
 
@@ -350,6 +351,23 @@ def get_prompt(self) -> str:
                 else:
                     ret += role
             return ret
+        elif self.sep_style == SeparatorStyle.QWEN2_AUDIO:
+            ret = "" if system_prompt == "" else system_prompt + self.sep
+
+            counter = 1
+            for role, message in self.messages:
+                if message:
+                    while self.audio_token in message:
+                        message = message.replace(
+                            self.audio_token, self.audio_token.format(idx=counter), 1
+                        )
+                        counter += 1
+
+                    ret += role + "\n" + message + self.sep
+                else:
+                    ret += role + "\n"
+
+            return ret
         else:
             raise ValueError(f"Invalid style: {self.sep_style}")
 
@@ -904,6 +922,20 @@ def generate_chat_conv(
     )
 
 
+register_conv_template(
+    Conversation(
+        name="qwen2-audio",
+        system_template="<|im_start|>system\n{system_message}",
+        system_message="You are a helpful assistant.",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep="<|im_end|>\n",
+        sep_style=SeparatorStyle.QWEN2_AUDIO,
+        stop_str=["<|im_end|>"],
+        audio_token="Audio {idx}: <|audio_bos|><|AUDIO|><|audio_eos|>\n",
+    )
+)
+
+
 @register_conv_template_matching_function
 def match_internvl(model_path: str):
     if re.search(r"internvl2_5", model_path, re.IGNORECASE):
@@ -956,6 +988,8 @@ def match_qwen_chat_ml(model_path: str):
         return "gme-qwen2-vl"
     if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
         return "qwen2-vl"
+    if re.search(r"qwen.*audio", model_path, re.IGNORECASE):
+        return "qwen2-audio"
     if re.search(
         r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
         model_path,

python/sglang/srt/disaggregation/utils.py

Lines changed: 1 addition & 1 deletion

@@ -74,7 +74,7 @@ def __init__(
     def available_size(self):
         return len(self.free_slots)
 
-    def alloc(self) -> List[int]:
+    def alloc(self) -> Optional[int]:
         if len(self.free_slots) == 0:
             return None
 
python/sglang/srt/entrypoints/http_server.py

Lines changed: 20 additions & 0 deletions

@@ -712,6 +712,26 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Re
     return ORJSONResponse(content=response_data, status_code=200)
 
 
+@app.post("/pause_generation")
+async def pause_generation(request: Request):
+    """Pause generation."""
+    await _global_state.tokenizer_manager.pause_generation()
+    return ORJSONResponse(
+        content={"message": "Generation paused successfully.", "status": "ok"},
+        status_code=200,
+    )
+
+
+@app.post("/continue_generation")
+async def continue_generation(request: Request):
+    """Continue generation."""
+    await _global_state.tokenizer_manager.continue_generation()
+    return ORJSONResponse(
+        content={"message": "Generation continued successfully.", "status": "ok"},
+        status_code=200,
+    )
+
+
 ##### OpenAI-compatible API endpoints #####
 
 

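Neither route takes a request body, so exercising them is a one-liner each. A minimal usage sketch with `requests`; the port is an assumption (SGLang's usual default of 30000), not part of this diff:

# Minimal usage sketch; assumes an SGLang server at localhost:30000.
import requests

base = "http://localhost:30000"
print(requests.post(f"{base}/pause_generation").json())
# -> {"message": "Generation paused successfully.", "status": "ok"}
print(requests.post(f"{base}/continue_generation").json())
# -> {"message": "Generation continued successfully.", "status": "ok"}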
python/sglang/srt/hf_transformers_utils.py

Lines changed: 2 additions & 1 deletion

@@ -42,7 +42,7 @@
 )
 from sglang.srt.configs.internvl import InternVLChatConfig
 from sglang.srt.connector import create_remote_connector
-from sglang.srt.utils import is_remote_url
+from sglang.srt.utils import is_remote_url, lru_cache_frozenset
 
 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,

@@ -103,6 +103,7 @@ def get_hf_text_config(config: PretrainedConfig):
     return config
 
 
+@lru_cache_frozenset(maxsize=32)
 def get_config(
     model: str,
     trust_remote_code: bool,
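
The helper's name hints at why plain functools.lru_cache would not do here: get_config accepts unhashable arguments such as dict overrides, which lru_cache rejects. Below is a minimal sketch of that idea, assuming the decorator freezes dict/list arguments into hashable cache keys; the real implementation lives in sglang.srt.utils and may differ:

import functools
from collections import OrderedDict

def lru_cache_frozenset(maxsize=32):
    """Sketch only: an lru_cache variant that tolerates unhashable
    (dict/list) arguments by freezing them into hashable cache keys."""

    def freeze(obj):
        # Recursively convert dicts/lists/sets into hashable equivalents.
        if isinstance(obj, dict):
            return frozenset((k, freeze(v)) for k, v in obj.items())
        if isinstance(obj, (list, tuple, set)):
            return tuple(freeze(v) for v in obj)
        return obj

    def decorator(fn):
        cache = OrderedDict()

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            key = (freeze(args), freeze(kwargs))
            if key in cache:
                cache.move_to_end(key)  # mark as most recently used
                return cache[key]
            result = fn(*args, **kwargs)
            cache[key] = result
            if len(cache) > maxsize:
                cache.popitem(last=False)  # evict the least recently used
            return result

        return wrapper

    return decorator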

python/sglang/srt/layers/activation.py

Lines changed: 2 additions & 2 deletions

@@ -46,11 +46,11 @@
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
-logger = logging.getLogger(__name__)
-
 if is_npu():
     import torch_npu
 
+logger = logging.getLogger(__name__)
+
 
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
python/sglang/srt/layers/logits_processor.py

Lines changed: 2 additions & 2 deletions

@@ -436,8 +436,8 @@ def _get_logits(
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata(hidden_states)
             hidden_states, local_hidden_states = (
-                logits_metadata.gathered_buffer,
-                hidden_states.clone(),
+                torch.empty_like(logits_metadata.gathered_buffer),
+                hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
 
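The swap changes buffer ownership rather than values: previously the gather destination aliased the shared logits_metadata.gathered_buffer, and the local input was cloned, presumably to keep source and destination from aliasing during the gather; now the destination is a fresh buffer, so the local tensor can be passed as-is. A toy illustration of the two patterns, with hypothetical shapes that are not from this diff:

import torch

gathered_buffer = torch.zeros(8, 4)  # stand-in for the shared scratch buffer
hidden_states = torch.randn(2, 4)

# Old pattern: the destination aliases the shared buffer, so the source is
# cloned defensively before the gather writes into the alias.
dst, src = gathered_buffer, hidden_states.clone()

# New pattern: gather into a fresh buffer; the source needs no copy.
dst, src = torch.empty_like(gathered_buffer), hidden_states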
python/sglang/srt/managers/io_struct.py

Lines changed: 1 addition & 0 deletions

@@ -39,6 +39,7 @@ class SessionParams:
     rid: Optional[str] = None
     offset: Optional[int] = None
     replace: Optional[bool] = None
+    drop_previous_output: Optional[bool] = None
 
 
 AudioDataItem = Union[str, Dict]
Lines changed: 94 additions & 0 deletions

@@ -0,0 +1,94 @@
+import re
+from typing import List, Union
+
+import torch
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration
+
+
+class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
+    models = [Qwen2AudioForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
+        self.AUDIO_TOKEN_REGEX = re.compile(
+            r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
+        )
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        **kwargs,
+    ):
+        audio_data = request_obj.audio_data
+        if not isinstance(audio_data, list):
+            audio_data = [audio_data]
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            max_req_input_len=max_req_input_len,
+            audio_data=audio_data,
+            multimodal_tokens=MultimodalSpecialTokens(
+                audio_token=self.AUDIO_TOKEN,
+                audio_token_regex=self.AUDIO_TOKEN_REGEX,
+            ),
+        )
+        if base_output is None:
+            return None
+
+        res = self.process_mm_data(
+            input_text=base_output.input_text,
+            audio=base_output.audios,
+        )
+
+        # Collect special token ids
+        tokenizer = self._processor.tokenizer
+        audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_bos|>")
+        audio_token_id = tokenizer.convert_tokens_to_ids("<|AUDIO|>")
+        audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>")
+
+        items = []
+        input_ids = res["input_ids"].flatten()
+
+        if (
+            "input_features" in res
+            and res["input_features"] is not None
+            and len(res["input_features"]) != 0
+        ):
+            if audio_start_id is not None and audio_end_id is not None:
+                audio_offsets = self.get_mm_items_offset_by_pair(
+                    input_ids=input_ids,
+                    mm_start_id=audio_start_id,
+                    mm_end_id=audio_end_id,
+                )
+            else:
+                audio_offsets = None
+
+            input_lengths = res["feature_attention_mask"].sum(dim=-1)
+            input_lengths = (input_lengths - 1) // 2 + 1
+            output_lengths = (input_lengths - 2) // 2 + 1
+
+            item = MultimodalDataItem(
+                audio_features=res["input_features"],
+                audio_feature_lens=output_lengths,
+                audio_offsets=audio_offsets,
+                modality=Modality.AUDIO,
+            )
+            items += [item]
+
+        return {
+            "mm_items": items,
+            "input_ids": input_ids.tolist(),
+            "audio_start_id": audio_start_id,
+            "audio_token_id": audio_token_id,
+            "audio_end_id": audio_end_id,
+        }
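
The two length formulas above mirror two stages of temporal downsampling in the audio encoder. A quick worked check, assuming a Whisper-style feature extractor at 100 frames per second, so a 30 s clip yields 3000 mel frames (the frame rate and encoder layout are assumptions, not stated in this diff):

# Worked check of the length arithmetic in process_mm_data_async.
# Assumption: Whisper-style features at 100 frames/s -> 30 s clip = 3000 frames.
mel_frames = 3000
input_lengths = (mel_frames - 1) // 2 + 1      # first stride-2 stage -> 1500
output_lengths = (input_lengths - 2) // 2 + 1  # second stage -> 750 audio tokens
print(input_lengths, output_lengths)  # 1500 750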
