
Commit e8afb26

fix: add quant_utils
1 parent: fd2a8cc

1 file changed: 166 additions, 0 deletions
@@ -0,0 +1,166 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py

from typing import Optional

import numpy
import torch
from sgl_kernel.scalar_type import ScalarType


def get_pack_factor(num_bits):
    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits
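# Editor's note (illustrative, not part of the committed file): the pack factor
# is simply how many low-bit values fit in a 32-bit word, e.g.
#     get_pack_factor(4) == 8   # eight 4-bit values per uint32
#     get_pack_factor(8) == 4   # four 8-bit values per uint32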


def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[:, i::pack_factor] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res
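# Editor's note (illustrative, not part of the committed file): with num_bits=4,
# pack_cols stores eight consecutive columns per 32-bit word, placing the value
# from column i at bit offset 4 * i:
#
#     q_w = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
#     packed = pack_cols(q_w, num_bits=4, size_k=1, size_n=8)
#     assert packed.item() == 0x76543210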


def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k,
        size_n // pack_factor,
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor
    )

    orig_device = packed_q_w.device

    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

    mask = (1 << num_bits) - 1
    for i in range(pack_factor):
        vals = packed_q_w_cpu & mask
        packed_q_w_cpu >>= num_bits
        q_res[:, i::pack_factor] = vals

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res
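# Editor's note (illustrative, not part of the committed file): unpack_cols is
# the inverse of pack_cols, so a pack/unpack round trip restores the codes:
#
#     q_w = torch.randint(0, 16, (128, 64))                       # 4-bit codes
#     packed = pack_cols(q_w, num_bits=4, size_k=128, size_n=64)  # (128, 8) int32
#     restored = unpack_cols(packed, num_bits=4, size_k=128, size_n=64)
#     assert bool((restored == q_w).all())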


def quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: Optional[int],
    zero_points: bool = False,
    ref_zero_points_after_scales: bool = False,
):
    assert (
        quant_type.is_integer()
    ), "Floating point quantization may work but has not been tested"
    assert not zero_points or group_size is not None, (
        "to have group zero points, group_size must be provided "
        "(-1 group_size is channelwise)"
    )

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    if group_size == -1:
        group_size = size_k

    # Reshape to [groupsize, -1]
    if group_size is not None and group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    max_q_val = quant_type.max()
    min_q_val = quant_type.min()

    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
    maybe_w_zp = None
    if group_size is not None:
        if zero_points:
            assert not quant_type.is_signed() and quant_type.max() > 0
            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
            maybe_w_zp = (
                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
            )
        else:
            # If the bias is such that there are no possible negative/positive
            # values, set the max value to inf to avoid divide by 0
            w_s = torch.max(
                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
            )

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    # For some kernels (namely Machete) the zero-points are applied after the
    # scales are applied, for this case computing the reference in similar way
    # allows us to use tighter error tolerances in our unit tests.
    if ref_zero_points_after_scales and maybe_w_zp is not None:
        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
    else:
        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    if quant_type.has_bias():
        w_q += quant_type.bias

    # Restore original shapes
    if group_size is not None and group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)
        w_s = w_s.reshape((-1, size_n)).contiguous()

    if maybe_w_zp is not None:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s if group_size is not None else None,
        maybe_w_zp,
    )
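
For context, a minimal usage sketch of quantize_weights (an editor's illustration, not part of the commit). It assumes quantize_weights is imported from the module added here, and that sgl_kernel.scalar_type exposes the same scalar_types presets as vLLM's scalar_type module (e.g. scalar_types.uint4b8, an unsigned 4-bit type with a bias of 8); adjust those names if the local module differs.

import torch

from sgl_kernel.scalar_type import scalar_types  # assumed preset table, mirrors vLLM

# Group-wise symmetric 4-bit quantization: one scale per 128 rows of each column.
w = torch.randn(512, 256)
w_ref, w_q, w_s, w_zp = quantize_weights(
    w, scalar_types.uint4b8, group_size=128, zero_points=False
)
# w_q:   int32 codes in [0, 15] (the type's bias of 8 has been added)
# w_s:   per-group scales with shape (512 // 128, 256)
# w_ref: dequantized reference in the original dtype, useful for error checks
# w_zp:  None, because zero_points=False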
