
Commit 8e03b64

AniZpZ, huangtingwei9988, yych0745, HandH1998, and 弋云 authored
[1/n] apply wna16marlin kernel in moe weight only quantization (#7683)
Co-authored-by: 晟海 <[email protected]>
Co-authored-by: yych0745 <[email protected]>
Co-authored-by: HandH1998 <[email protected]>
Co-authored-by: 弋云 <[email protected]>
Co-authored-by: walker-ai <[email protected]>
1 parent b3fa5dc commit 8e03b64

27 files changed: +6104 -1 lines changed
Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py

from typing import Optional

import numpy
import torch
from sgl_kernel.scalar_type import ScalarType


def get_pack_factor(num_bits):
    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits


def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[:, i::pack_factor] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k,
        size_n // pack_factor,
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor
    )

    orig_device = packed_q_w.device

    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

    mask = (1 << num_bits) - 1
    for i in range(pack_factor):
        vals = packed_q_w_cpu & mask
        packed_q_w_cpu >>= num_bits
        q_res[:, i::pack_factor] = vals

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: Optional[int],
    zero_points: bool = False,
    ref_zero_points_after_scales: bool = False,
):
    assert (
        quant_type.is_integer()
    ), "Floating point quantization may work but has not been tested"
    assert not zero_points or group_size is not None, (
        "to have group zero points, group_size must be provided "
        "(-1 group_size is channelwise)"
    )

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    if group_size == -1:
        group_size = size_k

    # Reshape to [group_size, -1]
    if group_size is not None and group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    max_q_val = quant_type.max()
    min_q_val = quant_type.min()

    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
    maybe_w_zp = None
    if group_size is not None:
        if zero_points:
            assert not quant_type.is_signed() and quant_type.max() > 0
            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
            maybe_w_zp = (
                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
            )
        else:
            # If the bias is such that there are no possible negative/positive
            # values, set the max value to inf to avoid divide by 0
            w_s = torch.max(
                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
            )

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    # For some kernels (namely Machete) the zero-points are applied after the
    # scales are applied. For this case, computing the reference in a similar
    # way allows us to use tighter error tolerances in our unit tests.
    if ref_zero_points_after_scales and maybe_w_zp is not None:
        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
    else:
        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    if quant_type.has_bias():
        w_q += quant_type.bias

    # Restore original shapes
    if group_size is not None and group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)
        w_s = w_s.reshape((-1, size_n)).contiguous()

    if maybe_w_zp is not None:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s if group_size is not None else None,
        maybe_w_zp,
    )
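
The helpers above are the host-side reference utilities for the Marlin WNA16 path: quantize_weights returns a dequantized reference, the integer weights, per-group scales, and optional zero points, while pack_cols/unpack_cols convert between an unpacked integer matrix and 32-bit packed columns. The following usage sketch is not part of the commit; it assumes the new file is importable (its path is not visible in this diff, so quant_utils below is a placeholder) and that sgl_kernel.scalar_type also exposes a scalar_types.uint4b8 entry, as in the vLLM module this file is adapted from.

import torch

from sgl_kernel.scalar_type import scalar_types  # assumed: provides uint4b8 as in vLLM
from quant_utils import (  # placeholder import; the new file's path is not shown above
    get_pack_factor,
    pack_cols,
    quantize_weights,
    unpack_cols,
)

size_k, size_n, group_size, num_bits = 256, 128, 64, 4
w = torch.randn(size_k, size_n, dtype=torch.float16)

# Symmetric 4-bit quantization (values stored with a bias of 8, the "u4b8" in the
# kernel names) with one scale per group of 64 input channels.
w_ref, w_q, w_s, w_zp = quantize_weights(
    w, scalar_types.uint4b8, group_size, zero_points=False
)

# Pack eight 4-bit values into each int32 along N, then verify the round trip.
packed = pack_cols(w_q, num_bits, size_k, size_n)
assert packed.shape == (size_k, size_n // get_pack_factor(num_bits))
assert torch.equal(unpack_cols(packed, num_bits, size_k, size_n), w_q)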

sgl-kernel/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
@@ -250,6 +250,15 @@ set(SOURCES
 "csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
 "csrc/kvcacheio/transfer.cu"
 "csrc/common_extension.cc"
+"csrc/moe/marlin_moe_wna16/ops.cu"
+"csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu"
+"csrc/moe/marlin_moe_wna16/awq_marlin_repack.cu"
+"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu"
+"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu"
+"csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu"
+"csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu"
+"csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu"
+"csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu"
 "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
 "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
 "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"

sgl-kernel/csrc/common_extension.cc

Lines changed: 18 additions & 1 deletion
@@ -17,7 +17,6 @@ limitations under the License.
 #include <torch/library.h>

 #include "sgl_kernel_ops.h"
-
 TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
 /*
 * From csrc/allreduce
@@ -209,6 +208,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
 m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
 m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);

+m.def("gptq_marlin_repack(Tensor! b_q_weight, Tensor! perm, int size_k, int size_n, int num_bits) -> Tensor");
+m.impl("gptq_marlin_repack", torch::kCUDA, &marlin_moe_wna16::gptq_marlin_repack);
+
+m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor");
+m.impl("awq_marlin_repack", torch::kCUDA, &marlin_moe_wna16::awq_marlin_repack);
+
 /*
 * From csrc/speculative
 */
@@ -303,6 +308,18 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
 "top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? "
 "maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
 m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
+m.def(
+    "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
+    "Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
+    "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
+    "Tensor sorted_token_ids,"
+    "Tensor! expert_ids, Tensor! num_tokens_past_padded,"
+    "Tensor! topk_weights, int moe_block_size, int top_k, "
+    "bool mul_topk_weights, bool is_ep, int b_q_type_id,"
+    "int size_m, int size_n, int size_k,"
+    "bool is_full_k, bool use_atomic_add,"
+    "bool use_fp32_reduce, bool is_zp_float) -> Tensor");
+m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);

 /*
 * From Sparse Flash Attention
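
Once the extension is built, these registrations make the new kernels callable from Python through the torch.ops.sgl_kernel namespace. The sketch below is not part of the commit and only exercises the simplest entry point, awq_marlin_repack. The packed layout it assumes (size_k rows by size_n // 8 int32 columns for 4-bit weights, i.e. the pack_cols layout above) mirrors the equivalent vLLM op and is an assumption, as is the claim that importing sgl_kernel is enough to register the ops; moe_wna16_marlin_gemm takes the much longer argument list shown in its schema and is not reproduced here.

import torch
import sgl_kernel  # assumed: importing the package loads the extension and registers the ops

size_k, size_n, num_bits = 1024, 512, 4
pack_factor = 32 // num_bits  # eight 4-bit values per int32

# AWQ-style packed weight: each int32 holds pack_factor values along the N axis
# (layout assumed to match pack_cols / the vLLM awq_marlin_repack op).
qweight = torch.randint(
    0, 2**31 - 1, (size_k, size_n // pack_factor), dtype=torch.int32, device="cuda"
)

# Repack into the Marlin tile layout consumed by moe_wna16_marlin_gemm
# (registered for CUDA above, so the input must live on a CUDA device).
marlin_qweight = torch.ops.sgl_kernel.awq_marlin_repack(qweight, size_k, size_n, num_bits)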
