Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a3a1515
ckpt: fix loader for unsloth phi-4-mini model
yyihuang Apr 2, 2025
dd0b3c3
add empty bnb map for phi-4-mini, revert prev
yyihuang Apr 3, 2025
4a60f9f
Merge branch 'main' of github.com:sgl-project/sglang into unsloth-phi…
yyihuang Apr 3, 2025
98f3361
rm empty dict, update bnb mapping to bypass replace at up_proj
yyihuang Apr 3, 2025
8642d77
update phi-4-mini bnb offset
yyihuang Apr 4, 2025
3fadea6
Merge branch 'main' of github.com:sgl-project/sglang into unsloth-phi…
yyihuang Apr 4, 2025
add3e5b
fmt
yyihuang Apr 4, 2025
a0bb665
upd doc
yyihuang Apr 4, 2025
5d2eeaa
Merge branch 'main' of github.com:sgl-project/sglang into unsloth-phi…
yyihuang Apr 4, 2025
64f8197
add test
yyihuang Apr 4, 2025
2d7f5db
add test load format bnb
yyihuang Apr 4, 2025
da06f6c
Merge branch 'main' of github.com:sgl-project/sglang into unsloth-phi…
yyihuang Apr 4, 2025
e4861d3
Merge branch 'main' of github.com:sgl-project/sglang into unsloth-phi…
yyihuang Apr 5, 2025
8f0fd7d
Merge branch 'main' into unsloth-phi-4-mini
hnyls2002 Apr 9, 2025
55ee6fc
Merge branch 'main' of github.com:sgl-project/sglang into unsloth-phi…
yyihuang Apr 12, 2025
a2b8bb2
Merge branch 'main' into unsloth-phi-4-mini
zhaochenyang20 Apr 13, 2025
a6cd0d7
Merge branch 'main' into unsloth-phi-4-mini
yyihuang Apr 13, 2025
b2b991a
add unsloth to test suite
yyihuang Apr 13, 2025
c8f71a8
Merge branch 'main' into unsloth-phi-4-mini
yyihuang Apr 13, 2025
83d32ab
disable unsloth ci
yyihuang Apr 13, 2025
ec5f1f0
Merge branch 'main' into unsloth-phi-4-mini
yyihuang Apr 13, 2025
c496e3f
Merge branch 'main' into unsloth-phi-4-mini
yyihuang Apr 13, 2025
5532b91
Merge branch 'main' into unsloth-phi-4-mini
zhaochenyang20 Apr 13, 2025
88f7f49
Merge branch 'main' into unsloth-phi-4-mini
yyihuang Apr 14, 2025
4eb7611
fmt
yyihuang Apr 14, 2025
9ba6756
Merge branch 'main' into unsloth-phi-4-mini
zhyncs Apr 14, 2025
0f219d9
Merge branch 'main' into unsloth-phi-4-mini
zhaochenyang20 Apr 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions python/sglang/srt/layers/linear.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""

import itertools
import logging
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple
Expand Down Expand Up @@ -61,12 +62,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset):


def adjust_bitsandbytes_4bit_shard(
param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
param: Parameter, shard_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
) -> Tuple[int, int]:
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""

total, _ = qkv_offsets["total"]
orig_offset, orig_size = qkv_offsets[loaded_shard_id]
total, _ = shard_offsets["total"]
orig_offset, orig_size = shard_offsets[loaded_shard_id]

quantized_total = param.data.shape[0]
quantized_offset = orig_offset * quantized_total // total
Expand Down Expand Up @@ -573,6 +574,8 @@ def weight_loader(
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size
packed_dim = getattr(param, "packed_dim", None)

use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
Expand All @@ -585,6 +588,17 @@ def weight_loader(
param, shard_size, shard_offset
)

if use_bitsandbytes_4bit:
index = list(itertools.accumulate([0] + self.output_sizes))
orig_offsets = {
str(i): (index[i], size)
for i, size in enumerate(self.output_sizes)
}
orig_offsets["total"] = (self.output_size, 0)
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_offsets, str(shard_id)
)

loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size
)
Expand Down
10 changes: 5 additions & 5 deletions python/sglang/srt/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,11 +362,11 @@ class LlamaForCausalLM(nn.Module):
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
".q_proj": (".qkv_proj", 0),
".k_proj": (".qkv_proj", 1),
".v_proj": (".qkv_proj", 2),
".gate_proj": (".gate_up_proj", 0),
".up_proj": (".gate_up_proj", 1),
}

def __init__(
Expand Down
213 changes: 213 additions & 0 deletions test/srt/models/test_unsloth_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)


class TestUnslothPhi4(CustomTestCase):
    """GSM8K few-shot accuracy check for the ``unsloth/phi-4`` checkpoint.

    The server is launched once per class with no extra command-line
    arguments and is shared by every test method.
    """

    @classmethod
    def setUpClass(cls):
        """Record the model/URL under test and start the server process."""
        cls.model = "unsloth/phi-4"
        cls.base_url = DEFAULT_URL_FOR_TEST
        extra_server_args = []
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=extra_server_args,
        )

    @classmethod
    def tearDownClass(cls):
        """Stop the server started in ``setUpClass`` via ``kill_process_tree``."""
        server_pid = cls.process.pid
        kill_process_tree(server_pid)

    def test_gsm8k(self):
        """Run the 5-shot GSM8K eval on 200 questions; accuracy must exceed 0.78."""
        # The port is whatever follows the last ':' in the base URL.
        server_port = int(self.base_url.rsplit(":", 1)[-1])
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval(eval_args)
        print(f"{metrics=}")
        accuracy = metrics["accuracy"]
        self.assertGreater(accuracy, 0.78)


class TestUnslothPhi4Bnb4bit(CustomTestCase):
    """GSM8K few-shot accuracy check for ``unsloth/phi-4-bnb-4bit``.

    The server is launched once per class with ``--load-format bitsandbytes``
    and is shared by every test method.
    """

    @classmethod
    def setUpClass(cls):
        """Record the model/URL under test and start the server process."""
        cls.model = "unsloth/phi-4-bnb-4bit"
        cls.base_url = DEFAULT_URL_FOR_TEST
        extra_server_args = ["--load-format", "bitsandbytes"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=extra_server_args,
        )

    @classmethod
    def tearDownClass(cls):
        """Stop the server started in ``setUpClass`` via ``kill_process_tree``."""
        server_pid = cls.process.pid
        kill_process_tree(server_pid)

    def test_gsm8k(self):
        """Run the 5-shot GSM8K eval on 200 questions; accuracy must exceed 0.75."""
        # The port is whatever follows the last ':' in the base URL.
        server_port = int(self.base_url.rsplit(":", 1)[-1])
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval(eval_args)
        print(f"{metrics=}")
        accuracy = metrics["accuracy"]
        self.assertGreater(accuracy, 0.75)


class TestUnslothPhi4UnslothBnb4bit(CustomTestCase):
    """GSM8K few-shot accuracy check for ``unsloth/phi-4-unsloth-bnb-4bit``.

    The server is launched once per class with ``--load-format bitsandbytes``
    and is shared by every test method.
    """

    @classmethod
    def setUpClass(cls):
        """Record the model/URL under test and start the server process."""
        cls.model = "unsloth/phi-4-unsloth-bnb-4bit"
        cls.base_url = DEFAULT_URL_FOR_TEST
        extra_server_args = ["--load-format", "bitsandbytes"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=extra_server_args,
        )

    @classmethod
    def tearDownClass(cls):
        """Stop the server started in ``setUpClass`` via ``kill_process_tree``."""
        server_pid = cls.process.pid
        kill_process_tree(server_pid)

    def test_gsm8k(self):
        """Run the 5-shot GSM8K eval on 200 questions; accuracy must exceed 0.75."""
        # The port is whatever follows the last ':' in the base URL.
        server_port = int(self.base_url.rsplit(":", 1)[-1])
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval(eval_args)
        print(f"{metrics=}")
        accuracy = metrics["accuracy"]
        self.assertGreater(accuracy, 0.75)


class TestUnslothPhi4MiniInstruct(CustomTestCase):
    """GSM8K few-shot accuracy check for ``unsloth/Phi-4-mini-instruct``.

    The server is launched once per class with no extra command-line
    arguments and is shared by every test method.
    """

    @classmethod
    def setUpClass(cls):
        """Record the model/URL under test and start the server process."""
        cls.model = "unsloth/Phi-4-mini-instruct"
        cls.base_url = DEFAULT_URL_FOR_TEST
        extra_server_args = []
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=extra_server_args,
        )

    @classmethod
    def tearDownClass(cls):
        """Stop the server started in ``setUpClass`` via ``kill_process_tree``."""
        server_pid = cls.process.pid
        kill_process_tree(server_pid)

    def test_gsm8k(self):
        """Run the 5-shot GSM8K eval on 200 questions; accuracy must exceed 0.65."""
        # The port is whatever follows the last ':' in the base URL.
        server_port = int(self.base_url.rsplit(":", 1)[-1])
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval(eval_args)
        print(f"{metrics=}")
        accuracy = metrics["accuracy"]
        self.assertGreater(accuracy, 0.65)


class TestUnslothPhi4MiniBnb4bit(CustomTestCase):
    """GSM8K few-shot accuracy check for ``unsloth/Phi-4-mini-instruct-bnb-4bit``.

    The server is launched once per class with ``--load-format bitsandbytes``
    and is shared by every test method.
    """

    @classmethod
    def setUpClass(cls):
        """Record the model/URL under test and start the server process."""
        cls.model = "unsloth/Phi-4-mini-instruct-bnb-4bit"
        cls.base_url = DEFAULT_URL_FOR_TEST
        extra_server_args = ["--load-format", "bitsandbytes"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=extra_server_args,
        )

    @classmethod
    def tearDownClass(cls):
        """Stop the server started in ``setUpClass`` via ``kill_process_tree``."""
        server_pid = cls.process.pid
        kill_process_tree(server_pid)

    def test_gsm8k(self):
        """Run the 5-shot GSM8K eval on 200 questions; accuracy must exceed 0.6."""
        # The port is whatever follows the last ':' in the base URL.
        server_port = int(self.base_url.rsplit(":", 1)[-1])
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval(eval_args)
        print(f"{metrics=}")
        accuracy = metrics["accuracy"]
        self.assertGreater(accuracy, 0.6)


class TestUnslothPhi4MiniUnslothBnb4bit(CustomTestCase):
    """GSM8K few-shot accuracy check for ``unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit``.

    The server is launched once per class with ``--load-format bitsandbytes``
    and is shared by every test method.
    """

    @classmethod
    def setUpClass(cls):
        """Record the model/URL under test and start the server process."""
        cls.model = "unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit"
        cls.base_url = DEFAULT_URL_FOR_TEST
        extra_server_args = ["--load-format", "bitsandbytes"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=extra_server_args,
        )

    @classmethod
    def tearDownClass(cls):
        """Stop the server started in ``setUpClass`` via ``kill_process_tree``."""
        server_pid = cls.process.pid
        kill_process_tree(server_pid)

    def test_gsm8k(self):
        """Run the 5-shot GSM8K eval on 200 questions; accuracy must exceed 0.6."""
        # The port is whatever follows the last ':' in the base URL.
        server_port = int(self.base_url.rsplit(":", 1)[-1])
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval(eval_args)
        print(f"{metrics=}")
        accuracy = metrics["accuracy"]
        self.assertGreater(accuracy, 0.6)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading