
Commit c896e6e

move test func
1 parent 17a2e36 commit c896e6e

File tree

2 files changed: +100 −103 lines changed


python/sglang/srt/layers/quantization/utils.py

Lines changed: 1 addition & 97 deletions
@@ -1,7 +1,7 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
 
 from types import MappingProxyType
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import List, Mapping, Tuple, Union
 
 import numpy
 import torch
@@ -210,99 +210,3 @@ def unpack_cols(
     q_res = q_res.contiguous()
 
     return q_res
-
-
-# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
-def quantize_weights(
-    w: torch.Tensor,
-    quant_type: ScalarType,
-    group_size: Optional[int],
-    zero_points: bool = False,
-    ref_zero_points_after_scales: bool = False,
-):
-    assert (
-        quant_type.is_integer()
-    ), "Floating point quantization may work but has not been tested"
-    assert not zero_points or group_size is not None, (
-        "to have group zero points, group_size must be provided "
-        "(-1 group_size is channelwise)"
-    )
-
-    orig_device = w.device
-    orig_type = w.dtype
-    size_k, size_n = w.shape
-
-    assert w.is_floating_point(), "w must be float"
-
-    if group_size == -1:
-        group_size = size_k
-
-    # Reshape to [groupsize, -1]
-    if group_size is not None and group_size < size_k:
-        w = w.reshape((-1, group_size, size_n))
-        w = w.permute(1, 0, 2)
-        w = w.reshape((group_size, -1))
-
-    # Compute scale for each group
-    max_val = torch.max(w, 0, keepdim=True).values
-    min_val = torch.min(w, 0, keepdim=True).values
-
-    max_q_val = quant_type.max()
-    min_q_val = quant_type.min()
-
-    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
-    maybe_w_zp = None
-    if group_size is not None:
-        if zero_points:
-            assert not quant_type.is_signed() and quant_type.max() > 0
-            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
-            maybe_w_zp = (
-                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
-            )
-        else:
-            # If the bias is such that there are no possible negative/positive
-            # values, set the max value to inf to avoid divide by 0
-            w_s = torch.max(
-                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
-                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
-            )
-
-    # Quantize
-    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
-    w_q = torch.clamp(w_q, min_q_val, max_q_val)
-
-    # Compute ref (dequantized)
-    # For some kernels (namely Machete) the zero-points are applied after the
-    # scales are applied, for this case computing the reference in similar way
-    # allows us to use tighter error tolerances in our unit tests.
-    if ref_zero_points_after_scales and maybe_w_zp is not None:
-        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
-    else:
-        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
-
-    if quant_type.has_bias():
-        w_q += quant_type.bias
-
-    # Restore original shapes
-    if group_size is not None and group_size < size_k:
-
-        def reshape_w(w):
-            w = w.reshape((group_size, -1, size_n))
-            w = w.permute(1, 0, 2)
-            w = w.reshape((size_k, size_n)).contiguous()
-            return w
-
-        w_q = reshape_w(w_q)
-        w_ref = reshape_w(w_ref)
-        w_s = w_s.reshape((-1, size_n)).contiguous()
-
-    if maybe_w_zp is not None:
-        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
-        maybe_w_zp = maybe_w_zp.to(device=orig_device)
-
-    return (
-        w_ref.to(device=orig_device),
-        w_q.to(device=orig_device),
-        w_s if group_size is not None else None,
-        maybe_w_zp,
-    )

sgl-kernel/tests/test_marlin_repack.py

Lines changed: 99 additions & 6 deletions
@@ -1,20 +1,113 @@
 import math
+from typing import Optional
 
 import numpy as np
 import pytest
 import torch
 from sgl_kernel import awq_marlin_repack
-from sgl_kernel.scalar_type import scalar_types
+from sgl_kernel.scalar_type import ScalarType, scalar_types
 
-from sglang.srt.layers.quantization.utils import (
-    get_pack_factor,
-    pack_cols,
-    quantize_weights,
-)
+from sglang.srt.layers.quantization.utils import get_pack_factor, pack_cols
 
 GPTQ_MARLIN_TILE = 16
 
 
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
+def quantize_weights(
+    w: torch.Tensor,
+    quant_type: ScalarType,
+    group_size: Optional[int],
+    zero_points: bool = False,
+    ref_zero_points_after_scales: bool = False,
+):
+    assert (
+        quant_type.is_integer()
+    ), "Floating point quantization may work but has not been tested"
+    assert not zero_points or group_size is not None, (
+        "to have group zero points, group_size must be provided "
+        "(-1 group_size is channelwise)"
+    )
+
+    orig_device = w.device
+    orig_type = w.dtype
+    size_k, size_n = w.shape
+
+    assert w.is_floating_point(), "w must be float"
+
+    if group_size == -1:
+        group_size = size_k
+
+    # Reshape to [groupsize, -1]
+    if group_size is not None and group_size < size_k:
+        w = w.reshape((-1, group_size, size_n))
+        w = w.permute(1, 0, 2)
+        w = w.reshape((group_size, -1))
+
+    # Compute scale for each group
+    max_val = torch.max(w, 0, keepdim=True).values
+    min_val = torch.min(w, 0, keepdim=True).values
+
+    max_q_val = quant_type.max()
+    min_q_val = quant_type.min()
+
+    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
+    maybe_w_zp = None
+    if group_size is not None:
+        if zero_points:
+            assert not quant_type.is_signed() and quant_type.max() > 0
+            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
+            maybe_w_zp = (
+                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
+            )
+        else:
+            # If the bias is such that there are no possible negative/positive
+            # values, set the max value to inf to avoid divide by 0
+            w_s = torch.max(
+                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
+                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
+            )
+
+    # Quantize
+    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
+    w_q = torch.clamp(w_q, min_q_val, max_q_val)
+
+    # Compute ref (dequantized)
+    # For some kernels (namely Machete) the zero-points are applied after the
+    # scales are applied, for this case computing the reference in similar way
+    # allows us to use tighter error tolerances in our unit tests.
+    if ref_zero_points_after_scales and maybe_w_zp is not None:
+        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
+    else:
+        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
+
+    if quant_type.has_bias():
+        w_q += quant_type.bias
+
+    # Restore original shapes
+    if group_size is not None and group_size < size_k:
+
+        def reshape_w(w):
+            w = w.reshape((group_size, -1, size_n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((size_k, size_n)).contiguous()
+            return w
+
+        w_q = reshape_w(w_q)
+        w_ref = reshape_w(w_ref)
+        w_s = w_s.reshape((-1, size_n)).contiguous()
+
+    if maybe_w_zp is not None:
+        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
+        maybe_w_zp = maybe_w_zp.to(device=orig_device)
+
+    return (
+        w_ref.to(device=orig_device),
+        w_q.to(device=orig_device),
+        w_s if group_size is not None else None,
+        maybe_w_zp,
+    )
+
+
 def awq_pack(
     q_w: torch.Tensor,
     num_bits: int,

0 commit comments
