
Commit 8364608

add model: qwen2-audio (#7596)

1 parent da3890e

File tree: 5 files changed, +330 -0 lines changed


python/sglang/srt/configs/model_config.py

Lines changed: 1 addition & 0 deletions
@@ -593,6 +593,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = False
         "Mistral3ForConditionalGeneration",
         "MultiModalityCausalLM",
         "MllamaForConditionalGeneration",
+        "Qwen2AudioForConditionalGeneration",
         "Qwen2VLForConditionalGeneration",
         "Qwen2_5_VLForConditionalGeneration",
         "KimiVLForConditionalGeneration",

python/sglang/srt/conversation.py

Lines changed: 34 additions & 0 deletions
@@ -59,6 +59,7 @@ class SeparatorStyle(IntEnum):
     METAMATH = auto()
     DeepSeekVL2 = auto()
     QWEN2_VL_EMBED = auto()
+    QWEN2_AUDIO = auto()
     GEMMA3 = auto()
     MPT = auto()

@@ -350,6 +351,23 @@ def get_prompt(self) -> str:
                 else:
                     ret += role
             return ret
+        elif self.sep_style == SeparatorStyle.QWEN2_AUDIO:
+            ret = "" if system_prompt == "" else system_prompt + self.sep
+
+            counter = 1
+            for role, message in self.messages:
+                if message:
+                    while self.audio_token in message:
+                        message = message.replace(
+                            self.audio_token, self.audio_token.format(idx=counter), 1
+                        )
+                        counter += 1
+
+                    ret += role + "\n" + message + self.sep
+                else:
+                    ret += role + "\n"
+
+            return ret
         else:
             raise ValueError(f"Invalid style: {self.sep_style}")

@@ -904,6 +922,20 @@ def generate_chat_conv(
     )


+register_conv_template(
+    Conversation(
+        name="qwen2-audio",
+        system_template="<|im_start|>system\n{system_message}",
+        system_message="You are a helpful assistant.",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep="<|im_end|>\n",
+        sep_style=SeparatorStyle.QWEN2_AUDIO,
+        stop_str=["<|im_end|>"],
+        audio_token="Audio {idx}: <|audio_bos|><|AUDIO|><|audio_eos|>\n",
+    )
+)
+
+
 @register_conv_template_matching_function
 def match_internvl(model_path: str):
     if re.search(r"internvl2_5", model_path, re.IGNORECASE):
@@ -956,6 +988,8 @@ def match_qwen_chat_ml(model_path: str):
         return "gme-qwen2-vl"
     if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
         return "qwen2-vl"
+    if re.search(r"qwen.*audio", model_path, re.IGNORECASE):
+        return "qwen2-audio"
     if re.search(
         r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
         model_path,
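
For reference, the QWEN2_AUDIO branch of get_prompt and the qwen2-audio template above number each audio placeholder with a running 1-based index. Below is a minimal standalone sketch of that rendering for a single-audio user turn; the messages list and the user text are made up for illustration, and the placeholder is placed in the message by hand (how the chat builder inserts conv.audio_token is not shown in this diff):

audio_token = "Audio {idx}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
sep = "<|im_end|>\n"
messages = [
    ("<|im_start|>user", audio_token + "What is said in this clip?"),
    ("<|im_start|>assistant", None),  # empty message: prompt ends at the assistant role
]

ret = "<|im_start|>system\nYou are a helpful assistant." + sep
counter = 1
for role, message in messages:
    if message:
        # Replace each placeholder occurrence, numbering audios in order.
        while audio_token in message:
            message = message.replace(audio_token, audio_token.format(idx=counter), 1)
            counter += 1
        ret += role + "\n" + message + sep
    else:
        ret += role + "\n"

print(ret)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>
# What is said in this clip?<|im_end|>
# <|im_start|>assistant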
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+import re
+from typing import List, Union
+
+import torch
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration
+
+
+class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
+    models = [Qwen2AudioForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
+        self.AUDIO_TOKEN_REGEX = re.compile(
+            r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
+        )
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        **kwargs,
+    ):
+        audio_data = request_obj.audio_data
+        if not isinstance(audio_data, list):
+            audio_data = [audio_data]
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            max_req_input_len=max_req_input_len,
+            audio_data=audio_data,
+            multimodal_tokens=MultimodalSpecialTokens(
+                audio_token=self.AUDIO_TOKEN,
+                audio_token_regex=self.AUDIO_TOKEN_REGEX,
+            ),
+        )
+        if base_output is None:
+            return None
+
+        res = self.process_mm_data(
+            input_text=base_output.input_text,
+            audio=base_output.audios,
+        )
+
+        # Collect special token ids
+        tokenizer = self._processor.tokenizer
+        audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_bos|>")
+        audio_token_id = tokenizer.convert_tokens_to_ids("<|AUDIO|>")
+        audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>")
+
+        items = []
+        input_ids = res["input_ids"].flatten()
+
+        if (
+            "input_features" in res
+            and res["input_features"] is not None
+            and len(res["input_features"]) != 0
+        ):
+            if audio_start_id is not None and audio_end_id is not None:
+                audio_offsets = self.get_mm_items_offset_by_pair(
+                    input_ids=input_ids,
+                    mm_start_id=audio_start_id,
+                    mm_end_id=audio_end_id,
+                )
+            else:
+                audio_offsets = None
+
+            input_lengths = res["feature_attention_mask"].sum(dim=-1)
+            input_lengths = (input_lengths - 1) // 2 + 1
+            output_lengths = (input_lengths - 2) // 2 + 1
+
+            item = MultimodalDataItem(
+                audio_features=res["input_features"],
+                audio_feature_lens=output_lengths,
+                audio_offsets=audio_offsets,
+                modality=Modality.AUDIO,
+            )
+            items += [item]
+
+        return {
+            "mm_items": items,
+            "input_ids": input_ids.tolist(),
+            "audio_start_id": audio_start_id,
+            "audio_token_id": audio_token_id,
+            "audio_end_id": audio_end_id,
+        }
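
The two formulas above halve the mel-frame count twice before the result is stored as audio_feature_lens. A small sketch of the arithmetic, assuming a single 30-second clip featurized to 3000 mel frames (the Whisper-style default) with no padding; the halvings presumably correspond to the stride-2 convolution and the pooling stage inside the HF Qwen2AudioEncoder:

import torch

# One clip, 3000 valid mel frames reported by the feature extractor (assumed values).
feature_attention_mask = torch.ones(1, 3000, dtype=torch.long)

input_lengths = feature_attention_mask.sum(dim=-1)  # tensor([3000]) mel frames
input_lengths = (input_lengths - 1) // 2 + 1         # tensor([1500])
output_lengths = (input_lengths - 2) // 2 + 1        # tensor([750]) embeddings per clip

print(input_lengths.item(), output_lengths.item())   # 1500 750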

python/sglang/srt/models/qwen2.py

Lines changed: 1 addition & 0 deletions
@@ -425,6 +425,7 @@ def __init__(
                 quant_config=quant_config,
                 prefix=add_prefix("lm_head", prefix),
             )
+
         else:
             # ranks other than the last rank will have a placeholder layer
             self.lm_head = PPMissingLayer()

python/sglang/srt/models/qwen2_audio.py

Lines changed: 200 additions & 0 deletions
@@ -0,0 +1,200 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/1d45d90e5d1552eccb6d8cc9b7bba283ccefb808/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
+# Copyright 2024 The Qwen team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
+import logging
+import math
+from functools import lru_cache, partial
+from typing import Any, Iterable, List, Optional, Tuple, Type, TypedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from transformers import AutoTokenizer, Qwen2AudioEncoderConfig, Qwen2Config
+from transformers.activations import ACT2FN
+from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioConfig
+from transformers.models.qwen2_audio.modeling_qwen2_audio import (
+    Qwen2AudioEncoder,
+    Qwen2AudioMultiModalProjector,
+)
+
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.activation import QuickGELU
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.utils import get_layer_id
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
+
+
+class Qwen2AudioForConditionalGeneration(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: Qwen2AudioConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        if getattr(self.config, "audio_config", None) is None:
+            self.config.audio_config = Qwen2AudioEncoderConfig(
+                self.config._name_or_path
+            )
+
+        if getattr(self.config, "text_config", None) is None:
+            self.config.text_config = Qwen2Config(self.config._name_or_path)
+
+        self.audio_tower = Qwen2AudioEncoder(
+            config.audio_config,
+        )
+        self.multi_modal_projector = Qwen2AudioMultiModalProjector(config)
+        self.language_model = Qwen2ForCausalLM(
+            config.text_config, quant_config, prefix=add_prefix("model", prefix)
+        )
+
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        # Get all special token IDs for audio
+        audio_token_id: int = getattr(
+            mm_inputs, "audio_token_id", mm_inputs.im_token_id
+        )
+
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens([audio_token_id])
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+    def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # Extract audio features from input items
+        input_features = torch.cat([item.audio_features for item in items], dim=0).type(
+            self.audio_tower.dtype
+        )
+
+        audio_embeds = self.audio_tower(input_features).last_hidden_state
+        audio_embeds = self.multi_modal_projector(audio_embeds)
+
+        audio_feature_lens = torch.cat([item.audio_feature_lens for item in items])
+        new_embeds = []
+        for i, d in zip(audio_feature_lens, audio_embeds):
+            new_embeds.append(d[: i.item()])
+
+        return torch.cat(new_embeds, dim=0)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        **kwargs: Any,
+    ) -> torch.Tensor:
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.language_model,
+            audio_data_embedding_func=self.get_audio_feature,
+            positions=positions,
+        )
+
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+
+            if self.config.text_config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name or "audio_tower" in name:
+                    continue
+                name_tmp = name.replace(weight_name, param_name)
+
+                # Skip loading extra bias for GPTQ models.
+                if name_tmp.endswith(".bias") and name_tmp not in params_dict:
+                    continue
+                param = params_dict[name_tmp]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                try:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                except KeyError:
+                    print(params_dict.keys())
+                    raise
+
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = Qwen2AudioForConditionalGeneration
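
load_weights above folds the text model's separate q/k/v and gate/up checkpoint tensors into the fused qkv_proj and gate_up_proj parameters, while audio_tower weights keep their original names. A standalone sketch of that renaming logic; the example key names are hypothetical, not taken from a real checkpoint:

stacked_params_mapping = [
    # (param_name, weight_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def remap(name: str):
    # Text-model q/k/v and gate/up tensors map onto the fused parameters;
    # anything under audio_tower is left untouched.
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name and "audio_tower" not in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None


print(remap("language_model.model.layers.0.self_attn.q_proj.weight"))
# ('language_model.model.layers.0.self_attn.qkv_proj.weight', 'q')
print(remap("audio_tower.layers.0.self_attn.k_proj.weight"))
# ('audio_tower.layers.0.self_attn.k_proj.weight', None)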
