
Commit 877e35d

Add get_hidden_dim to qwen3.py for correct lora (#7312)
1 parent cbdfb77 commit 877e35d

File tree

5 files changed: +240 -2 lines changed

python/sglang/srt/models/qwen3.py
python/sglang/test/runners.py
test/srt/models/lora/test_lora.py
test/srt/models/lora/test_lora_qwen3.py
test/srt/run_suite.py

python/sglang/srt/models/qwen3.py

Lines changed: 24 additions & 0 deletions
@@ -330,6 +330,30 @@ def __init__(
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)

+    def get_hidden_dim(self, module_name: str) -> Tuple[int]:
+        # return input_dim, output_dim
+        if module_name in ["q_proj", "qkv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_attention_heads,
+            )
+        elif module_name in ["o_proj"]:
+            return (
+                self.config.head_dim * self.config.num_attention_heads,
+                self.config.hidden_size,
+            )
+        elif module_name in ["kv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_key_value_heads,
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size
+        elif module_name == "down_proj":
+            return self.config.intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
     @torch.no_grad()
     def forward(
         self,
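Why this hook matters for LoRA (an inference from the diff, not stated in the commit message): in Qwen3 configs, head_dim * num_attention_heads need not equal hidden_size, so LoRA code that assumes every projection's input and output dims are both hidden_size would build wrongly shaped adapter matrices for the attention projections. (Incidentally, the annotation Tuple[int] would more precisely be Tuple[int, int], since every branch returns a pair.) Below is a minimal standalone sketch of the q_proj branch; the config numbers are illustrative assumptions, not values read from the real Qwen/Qwen3-4B config.

# A minimal sketch, assuming Qwen3-4B-like config values; check the real
# model config before relying on these numbers.
from types import SimpleNamespace

cfg = SimpleNamespace(
    hidden_size=2560,
    head_dim=128,
    num_attention_heads=32,
    num_key_value_heads=8,
)

# q_proj output dim per the get_hidden_dim branch above:
q_out = cfg.head_dim * cfg.num_attention_heads  # 128 * 32 = 4096

# The naive assumption "output_dim == hidden_size" fails here, which is
# exactly the case the new hook handles:
assert q_out != cfg.hidden_size  # 4096 != 2560

# A rank-r LoRA for q_proj would then need A: (r, 2560) and B: (4096, r),
# rather than A: (r, 2560) and B: (2560, r).
print("q_proj (input_dim, output_dim):", (cfg.hidden_size, q_out))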

python/sglang/test/runners.py

Lines changed: 6 additions & 1 deletion
@@ -134,10 +134,12 @@ def __init__(
         model_type: str = "generation",
         output_str_only: bool = False,
         trust_remote_code: bool = False,
+        patch_model_do_sample_false: bool = False,
     ):
         self.model_type = model_type
         self.output_str_only = output_str_only
         self.trust_remote_code = trust_remote_code
+        self.patch_model_do_sample_false = patch_model_do_sample_false

         self.in_queue = mp.Queue()
         self.out_queue = mp.Queue()
@@ -292,6 +294,7 @@ def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
                         torch_dtype=torch_dtype,
                         output_str_only=self.output_str_only,
                         token_ids_logprob=token_ids_logprob,
+                        patch_model_do_sample_false=self.patch_model_do_sample_false,
                     )
                 )
             elif self.model_type == "embedding":
@@ -380,6 +383,7 @@ def forward_generation_raw(
         lora_paths: Optional[List[str]] = None,
         output_str_only: bool = False,
         token_ids_logprob: Optional[int] = None,
+        patch_model_do_sample_false: Optional[bool] = False,
     ) -> ModelOutput:
         output_strs = []
         top_input_logprobs = []
@@ -407,7 +411,8 @@
                 )
             else:
                 model = base_model
-
+            if patch_model_do_sample_false:
+                model.generation_config.do_sample = False
             outputs = model.generate(
                 input_ids=input_ids,
                 generation_config=GenerationConfig(
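The new flag forces do_sample=False on the HF model's generation_config, so the HuggingFace reference decodes greedily and its outputs stay comparable with SRT's deterministic outputs. A usage sketch, pieced together from this diff and the test below (HFRunner's other parameters and defaults are taken on faith):

# Assumption: HFRunner accepts a torch dtype and the new flag exactly as
# used in the Qwen3 LoRA test added in this commit.
import torch

from sglang.test.runners import HFRunner

hf_runner = HFRunner(
    "Qwen/Qwen3-4B",
    torch_dtype=torch.float16,
    model_type="generation",
    patch_model_do_sample_false=True,  # sets generation_config.do_sample = False
)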

test/srt/models/lora/test_lora.py

Lines changed: 0 additions & 1 deletion
@@ -84,7 +84,6 @@ def ensure_reproducibility(self):
         torch.use_deterministic_algorithms(True)

     def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]):
-
         for model_case in model_cases:
             for torch_dtype in TORCH_DTYPES:
                 max_new_tokens = 32
test/srt/models/lora/test_lora_qwen3.py

Lines changed: 209 additions & 0 deletions
@@ -0,0 +1,209 @@
+# Copyright 2023-2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import multiprocessing as mp
+import os
+import random
+import unittest
+from typing import List
+
+from utils import TORCH_DTYPES, LoRAAdaptor, LoRAModelCase
+
+from sglang.test.runners import HFRunner, SRTRunner
+from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci
+
+LORA_MODELS_QWEN3 = [
+    LoRAModelCase(
+        base="Qwen/Qwen3-4B",
+        adaptors=[
+            LoRAAdaptor(
+                name="nissenj/Qwen3-4B-lora-v2",
+                prefill_tolerance=3e-1,
+            ),
+            LoRAAdaptor(
+                name="y9760210/Qwen3-4B-lora_model",
+                prefill_tolerance=3e-1,
+            ),
+        ],
+        max_loras_per_batch=2,
+    ),
+]
+
+
+TEST_MULTIPLE_BATCH_PROMPTS = [
+    """
+### Instruction:
+Tell me about llamas and alpacas
+### Response:
+Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing.
+### Question 2:
+What do you know about llamas?
+### Answer:
+""",
+    """
+### Instruction:
+Write a poem about the transformers Python library.
+Mention the word "large language models" in that poem.
+### Response:
+The Transformers are large language models,
+They're used to make predictions on text.
+""",
+    # "AI is a field of computer science focused on", TODO: Add it back after fixing its bug
+    "Computer science is the study of",
+    "Write a short story.",
+    "What are the main components of a computer?",
+]
+
+
+class TestLoRA(CustomTestCase):
+
+    def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]):
+        for model_case in model_cases:
+            for torch_dtype in TORCH_DTYPES:
+                max_new_tokens = 10
+                backend = "triton"
+                base_path = model_case.base
+                lora_adapter_paths = [a.name for a in model_case.adaptors]
+                assert len(lora_adapter_paths) >= 2
+
+                batches = [
+                    (
+                        [
+                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
+                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
+                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
+                        ],
+                        [
+                            None,
+                            lora_adapter_paths[0],
+                            lora_adapter_paths[1],
+                        ],
+                    ),
+                    (
+                        [
+                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
+                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
+                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
+                        ],
+                        [
+                            lora_adapter_paths[0],
+                            None,
+                            lora_adapter_paths[1],
+                        ],
+                    ),
+                    (
+                        [
+                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
+                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
+                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
+                        ],
+                        [lora_adapter_paths[0], lora_adapter_paths[1], None],
+                    ),
+                    (
+                        [
+                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
+                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
+                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
+                        ],
+                        [None, lora_adapter_paths[1], None],
+                    ),
+                    (
+                        [
+                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
+                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
+                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
+                        ],
+                        [None, None, None],
+                    ),
+                ]
+
+                print(
+                    f"\n========== Testing multiple batches on base '{base_path}' with backend={backend}, dtype={torch_dtype} ---"
+                )
+
+                # Initialize runners
+                srt_runner = SRTRunner(
+                    base_path,
+                    torch_dtype=torch_dtype,
+                    model_type="generation",
+                    lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]],
+                    max_loras_per_batch=len(lora_adapter_paths) + 1,
+                    lora_backend=backend,
+                    disable_radix_cache=True,
+                )
+                hf_runner = HFRunner(
+                    base_path,
+                    torch_dtype=torch_dtype,
+                    model_type="generation",
+                    patch_model_do_sample_false=True,
+                )
+
+                with srt_runner, hf_runner:
+                    for i, (prompts, lora_paths) in enumerate(batches):
+                        print(
+                            f"\n--- Running Batch {i+1} --- prompts: {prompts}, lora_paths: {lora_paths}"
+                        )
+
+                        srt_outputs = srt_runner.batch_forward(
+                            prompts,
+                            max_new_tokens=max_new_tokens,
+                            lora_paths=lora_paths,
+                        )
+
+                        hf_outputs = hf_runner.forward(
+                            prompts,
+                            max_new_tokens=max_new_tokens,
+                            lora_paths=lora_paths,
+                        )
+
+                        print("SRT outputs:", [s for s in srt_outputs.output_strs])
+                        print("HF outputs:", [s for s in hf_outputs.output_strs])
+
+                        for srt_out, hf_out in zip(
+                            srt_outputs.output_strs, hf_outputs.output_strs
+                        ):
+                            srt_str = srt_out.strip()
+                            hf_str = hf_out.strip()
+                            rouge_tol = model_case.rouge_l_tolerance
+                            rouge_score = calculate_rouge_l([srt_str], [hf_str])[0]
+                            if rouge_score < rouge_tol:
+                                raise AssertionError(
+                                    f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} "
+                                    f"for base '{base_path}', adaptor '{lora_paths}', backend '{backend}', prompt: '{prompts}...'"
+                                )
+
+                        print(f"--- Batch {i+1} Comparison Passed --- ")
+
+    def test_ci_lora_models(self):
+        self._run_lora_multiple_batch_on_model_cases(LORA_MODELS_QWEN3)
+
+    def test_all_lora_models(self):
+        if is_in_ci():
+            return
+        qwen_filtered_models = []
+        for model_case in LORA_MODELS_QWEN3:
+            if "ONLY_RUN" in os.environ and os.environ["ONLY_RUN"] != model_case.base:
+                continue
+            qwen_filtered_models.append(model_case)
+
+        self._run_lora_multiple_batch_on_model_cases(qwen_filtered_models)
+
+
+if __name__ == "__main__":
+    try:
+        mp.set_start_method("spawn")
+    except RuntimeError:
+        pass
+
+    unittest.main(warnings="ignore")
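To run just this new test outside the CI suite, the usual pattern for files under test/srt would be the following (an assumption about the local setup: the test imports the sibling utils module directly, so it is launched from its own directory, and it needs a GPU plus network access to fetch Qwen/Qwen3-4B and both LoRA adapters):

    cd test/srt/models/lora
    python3 test_lora_qwen3.py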

test/srt/run_suite.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ class TestFile:
     TestFile("models/lora/test_multi_lora_backend.py", 60),
     TestFile("models/lora/test_lora_cuda_graph.py", 250),
     TestFile("models/lora/test_lora_update.py", 800),
+    TestFile("models/lora/test_lora_qwen3.py", 97),
     TestFile("models/test_embedding_models.py", 73),
     # TestFile("models/test_clip_models.py", 52),
     TestFile("models/test_encoder_embedding_models.py", 100),
