@@ -1,7 +1,7 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
 
 from types import MappingProxyType
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import List, Mapping, Tuple, Union
 
 import numpy
 import torch
@@ -210,99 +210,3 @@ def unpack_cols(
     q_res = q_res.contiguous()
 
     return q_res
-
-
-# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
-def quantize_weights(
-    w: torch.Tensor,
-    quant_type: ScalarType,
-    group_size: Optional[int],
-    zero_points: bool = False,
-    ref_zero_points_after_scales: bool = False,
-):
-    assert (
-        quant_type.is_integer()
-    ), "Floating point quantization may work but has not been tested"
-    assert not zero_points or group_size is not None, (
-        "to have group zero points, group_size must be provided "
-        "(-1 group_size is channelwise)"
-    )
-
-    orig_device = w.device
-    orig_type = w.dtype
-    size_k, size_n = w.shape
-
-    assert w.is_floating_point(), "w must be float"
-
-    if group_size == -1:
-        group_size = size_k
-
-    # Reshape to [group_size, -1]
-    if group_size is not None and group_size < size_k:
-        w = w.reshape((-1, group_size, size_n))
-        w = w.permute(1, 0, 2)
-        w = w.reshape((group_size, -1))
-
-    # Compute scale for each group
-    max_val = torch.max(w, 0, keepdim=True).values
-    min_val = torch.min(w, 0, keepdim=True).values
-
-    max_q_val = quant_type.max()
-    min_q_val = quant_type.min()
-
-    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
-    maybe_w_zp = None
-    if group_size is not None:
-        if zero_points:
-            assert not quant_type.is_signed() and quant_type.max() > 0
-            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
-            maybe_w_zp = (
-                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
-            )
-        else:
-            # If the bias is such that there are no possible negative/positive
-            # values, set the max value to inf to avoid divide by 0
-            w_s = torch.max(
-                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
-                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
-            )
-
-    # Quantize
-    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
-    w_q = torch.clamp(w_q, min_q_val, max_q_val)
-
-    # Compute ref (dequantized)
-    # For some kernels (namely Machete) the zero-points are applied after the
-    # scales are applied; for this case, computing the reference in a similar
-    # way allows us to use tighter error tolerances in our unit tests.
-    if ref_zero_points_after_scales and maybe_w_zp is not None:
-        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
-    else:
-        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
-
-    if quant_type.has_bias():
-        w_q += quant_type.bias
-
-    # Restore original shapes
-    if group_size is not None and group_size < size_k:
-
-        def reshape_w(w):
-            w = w.reshape((group_size, -1, size_n))
-            w = w.permute(1, 0, 2)
-            w = w.reshape((size_k, size_n)).contiguous()
-            return w
-
-        w_q = reshape_w(w_q)
-        w_ref = reshape_w(w_ref)
-        w_s = w_s.reshape((-1, size_n)).contiguous()
-
-    if maybe_w_zp is not None:
-        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
-        maybe_w_zp = maybe_w_zp.to(device=orig_device)
-
-    return (
-        w_ref.to(device=orig_device),
-        w_q.to(device=orig_device),
-        w_s if group_size is not None else None,
-        maybe_w_zp,
-    )