|
| 1 | +# coding=utf-8 |
| 2 | +# Adapted from |
| 3 | +# https://github.com/huggingface/transformers/blob/1d45d90e5d1552eccb6d8cc9b7bba283ccefb808/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py |
| 4 | +# Copyright 2024 The Qwen team. |
| 5 | +# Copyright 2023 The vLLM team. |
| 6 | +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. |
| 7 | +# |
| 8 | +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX |
| 9 | +# and OPT implementations in this library. It has been modified from its |
| 10 | +# original forms to accommodate minor architectural differences compared |
| 11 | +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. |
| 12 | +# |
| 13 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 14 | +# you may not use this file except in compliance with the License. |
| 15 | +# You may obtain a copy of the License at |
| 16 | +# |
| 17 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 18 | +# |
| 19 | +# Unless required by applicable law or agreed to in writing, software |
| 20 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 21 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 22 | +# See the License for the specific language governing permissions and |
| 23 | +# limitations under the License. |
| 24 | +"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" |
| 25 | +import logging |
| 26 | +import math |
| 27 | +from functools import lru_cache, partial |
| 28 | +from typing import Any, Iterable, List, Optional, Tuple, Type, TypedDict |
| 29 | + |
| 30 | +import torch |
| 31 | +import torch.nn as nn |
| 32 | +import torch.nn.functional as F |
| 33 | +from einops import rearrange |
| 34 | +from transformers import AutoTokenizer, Qwen2AudioEncoderConfig, Qwen2Config |
| 35 | +from transformers.activations import ACT2FN |
| 36 | +from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioConfig |
| 37 | +from transformers.models.qwen2_audio.modeling_qwen2_audio import ( |
| 38 | + Qwen2AudioEncoder, |
| 39 | + Qwen2AudioMultiModalProjector, |
| 40 | +) |
| 41 | + |
| 42 | +from sglang.srt.hf_transformers_utils import get_processor |
| 43 | +from sglang.srt.layers.activation import QuickGELU |
| 44 | +from sglang.srt.layers.attention.vision import VisionAttention |
| 45 | +from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear |
| 46 | +from sglang.srt.layers.logits_processor import LogitsProcessor |
| 47 | +from sglang.srt.layers.pooler import Pooler, PoolingType |
| 48 | +from sglang.srt.layers.quantization.base_config import QuantizationConfig |
| 49 | +from sglang.srt.layers.utils import get_layer_id |
| 50 | +from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead |
| 51 | +from sglang.srt.managers.mm_utils import ( |
| 52 | + MultiModalityDataPaddingPatternMultimodalTokens, |
| 53 | + general_mm_embed_routine, |
| 54 | +) |
| 55 | +from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs |
| 56 | +from sglang.srt.model_executor.forward_batch_info import ForwardBatch |
| 57 | +from sglang.srt.model_loader.weight_utils import default_weight_loader |
| 58 | +from sglang.srt.models.qwen2 import Qwen2ForCausalLM |
| 59 | +from sglang.srt.utils import add_prefix |
| 60 | + |
| 61 | +logger = logging.getLogger(__name__) |
| 62 | + |
| 63 | + |
| 64 | +class Qwen2AudioForConditionalGeneration(nn.Module): |
| 65 | + # BitandBytes specific attributes |
| 66 | + default_bitsandbytes_target_modules = [ |
| 67 | + ".gate_proj.", |
| 68 | + ".down_proj.", |
| 69 | + ".up_proj.", |
| 70 | + ".q_proj.", |
| 71 | + ".k_proj.", |
| 72 | + ".v_proj.", |
| 73 | + ".o_proj.", |
| 74 | + ] |
| 75 | + bitsandbytes_stacked_params_mapping = { |
| 76 | + # shard_name, weight_name, index |
| 77 | + "q_proj": ("qkv_proj", 0), |
| 78 | + "k_proj": ("qkv_proj", 1), |
| 79 | + "v_proj": ("qkv_proj", 2), |
| 80 | + "gate_proj": ("gate_up_proj", 0), |
| 81 | + "up_proj": ("gate_up_proj", 1), |
| 82 | + } |
| 83 | + |
| 84 | + def __init__( |
| 85 | + self, |
| 86 | + config: Qwen2AudioConfig, |
| 87 | + quant_config: Optional[QuantizationConfig] = None, |
| 88 | + prefix: str = "", |
| 89 | + ) -> None: |
| 90 | + super().__init__() |
| 91 | + |
| 92 | + self.config = config |
| 93 | + |
| 94 | + if getattr(self.config, "audio_config", None) is None: |
| 95 | + self.config.audio_config = Qwen2AudioEncoderConfig( |
| 96 | + self.config._name_or_path |
| 97 | + ) |
| 98 | + |
| 99 | + if getattr(self.config, "text_config", None) is None: |
| 100 | + self.config.text_config = Qwen2Config(self.config._name_or_path) |
| 101 | + |
| 102 | + self.audio_tower = Qwen2AudioEncoder( |
| 103 | + config.audio_config, |
| 104 | + ) |
| 105 | + self.multi_modal_projector = Qwen2AudioMultiModalProjector(config) |
| 106 | + self.language_model = Qwen2ForCausalLM( |
| 107 | + config.text_config, quant_config, prefix=add_prefix("model", prefix) |
| 108 | + ) |
| 109 | + |
| 110 | + def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): |
| 111 | + # Get all special token IDs for audio |
| 112 | + audio_token_id: int = getattr( |
| 113 | + mm_inputs, "audio_token_id", mm_inputs.im_token_id |
| 114 | + ) |
| 115 | + |
| 116 | + pattern = MultiModalityDataPaddingPatternMultimodalTokens([audio_token_id]) |
| 117 | + return pattern.pad_input_tokens(input_ids, mm_inputs) |
| 118 | + |
| 119 | + def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: |
| 120 | + # Extract audio features from input items |
| 121 | + input_features = torch.cat([item.audio_features for item in items], dim=0).type( |
| 122 | + self.audio_tower.dtype |
| 123 | + ) |
| 124 | + |
| 125 | + audio_embeds = self.audio_tower(input_features).last_hidden_state |
| 126 | + audio_embeds = self.multi_modal_projector(audio_embeds) |
| 127 | + |
| 128 | + audio_feature_lens = torch.cat([item.audio_feature_lens for item in items]) |
| 129 | + new_embeds = [] |
| 130 | + for i, d in zip(audio_feature_lens, audio_embeds): |
| 131 | + new_embeds.append(d[: i.item()]) |
| 132 | + |
| 133 | + return torch.cat(new_embeds, dim=0) |
| 134 | + |
| 135 | + def forward( |
| 136 | + self, |
| 137 | + input_ids: torch.Tensor, |
| 138 | + positions: torch.Tensor, |
| 139 | + forward_batch: ForwardBatch, |
| 140 | + **kwargs: Any, |
| 141 | + ) -> torch.Tensor: |
| 142 | + hidden_states = general_mm_embed_routine( |
| 143 | + input_ids=input_ids, |
| 144 | + forward_batch=forward_batch, |
| 145 | + language_model=self.language_model, |
| 146 | + audio_data_embedding_func=self.get_audio_feature, |
| 147 | + positions=positions, |
| 148 | + ) |
| 149 | + |
| 150 | + return hidden_states |
| 151 | + |
| 152 | + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): |
| 153 | + stacked_params_mapping = [ |
| 154 | + # (param_name, shard_name, shard_id) |
| 155 | + ("qkv_proj", "q_proj", "q"), |
| 156 | + ("qkv_proj", "k_proj", "k"), |
| 157 | + ("qkv_proj", "v_proj", "v"), |
| 158 | + ("gate_up_proj", "gate_proj", 0), |
| 159 | + ("gate_up_proj", "up_proj", 1), |
| 160 | + ] |
| 161 | + params_dict = dict(self.named_parameters(remove_duplicate=False)) |
| 162 | + |
| 163 | + for name, loaded_weight in weights: |
| 164 | + if "rotary_emb.inv_freq" in name: |
| 165 | + continue |
| 166 | + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: |
| 167 | + # Models trained using ColossalAI may include these tensors in |
| 168 | + # the checkpoint. Skip them. |
| 169 | + continue |
| 170 | + |
| 171 | + if self.config.text_config.tie_word_embeddings and "lm_head.weight" in name: |
| 172 | + continue |
| 173 | + |
| 174 | + for param_name, weight_name, shard_id in stacked_params_mapping: |
| 175 | + if weight_name not in name or "audio_tower" in name: |
| 176 | + continue |
| 177 | + name_tmp = name.replace(weight_name, param_name) |
| 178 | + |
| 179 | + # Skip loading extra bias for GPTQ models. |
| 180 | + if name_tmp.endswith(".bias") and name_tmp not in params_dict: |
| 181 | + continue |
| 182 | + param = params_dict[name_tmp] |
| 183 | + weight_loader = param.weight_loader |
| 184 | + weight_loader(param, loaded_weight, shard_id) |
| 185 | + break |
| 186 | + else: |
| 187 | + try: |
| 188 | + # Skip loading extra bias for GPTQ models. |
| 189 | + if name.endswith(".bias") and name not in params_dict: |
| 190 | + continue |
| 191 | + param = params_dict[name] |
| 192 | + except KeyError: |
| 193 | + print(params_dict.keys()) |
| 194 | + raise |
| 195 | + |
| 196 | + weight_loader = getattr(param, "weight_loader", default_weight_loader) |
| 197 | + weight_loader(param, loaded_weight) |
| 198 | + |
| 199 | + |
| 200 | +EntryClass = Qwen2AudioForConditionalGeneration |
0 commit comments