12 changes: 10 additions & 2 deletions python/sglang/srt/layers/quantization/awq.py
@@ -43,11 +43,20 @@
except ImportError:
ops = None

from sglang.srt.utils import is_cuda
from sglang.srt.utils import is_cuda, is_hip

_is_cuda = is_cuda()
_is_hip = is_hip()
if _is_cuda:
from sgl_kernel import awq_dequantize, fused_marlin_moe
elif _is_hip:
from sglang.srt.layers.quantization.awq_triton import (
awq_dequantize_triton as awq_dequantize,
)

warnings.warn(f"HIP does not support fused_marlin_moe currently.")
else:
warnings.warn(f"Only CUDA and HIP support AWQ currently.")

logger = logging.getLogger(__name__)

@@ -398,7 +407,6 @@ def apply(
pack_factor = self.quant_config.pack_factor
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
reshaped_x = x.reshape(-1, x.shape[-1])

out = awq_dequantize(qweight, scales, qzeros)
out = torch.matmul(reshaped_x, out)

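For orientation, the dequantize-then-matmul path taken by `apply` above is the same on both backends; only the `awq_dequantize` implementation differs. The sketch below is my illustration, not part of the PR: it uses the Triton implementation added in this change, made-up sizes, and random packed data, and assumes a CUDA or ROCm GPU with Triton available.

```python
import torch

from sglang.srt.layers.quantization.awq_triton import awq_dequantize_triton

# Illustrative sizes only: 4-bit AWQ packs 8 values per int32, so pack_factor = 8.
M, K, N, pack_factor, group_size = 4, 512, 1024, 8, 128
device = "cuda"  # PyTorch ROCm builds also use the "cuda" device string

x = torch.randn(M, K, dtype=torch.float16, device=device)
qweight = torch.randint(-2**31, 2**31 - 1, (K, N // pack_factor),
                        dtype=torch.int32, device=device)
scales = torch.rand(K // group_size, N, dtype=torch.float16, device=device)
qzeros = torch.randint(-2**31, 2**31 - 1, (K // group_size, N // pack_factor),
                       dtype=torch.int32, device=device)

out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)   # (M, N)
reshaped_x = x.reshape(-1, x.shape[-1])                         # (M, K)
w = awq_dequantize_triton(qweight, scales, qzeros)              # (K, N), fp16
out = torch.matmul(reshaped_x, w).reshape(out_shape)            # (M, N)
print(out.shape)  # torch.Size([4, 1024])
```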
339 changes: 339 additions & 0 deletions python/sglang/srt/layers/quantization/awq_triton.py
@@ -0,0 +1,339 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/awq_triton.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import triton
import triton.language as tl

AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]


@triton.jit
def awq_dequantize_kernel(
qweight_ptr, # quantized matrix
scales_ptr, # scales, per group
zeros_ptr, # zeros, per group
group_size, # Should always be one of the supported group sizes
result_ptr, # Output matrix
num_cols, # input num cols in qweight
num_rows, # input num rows in qweight
BLOCK_SIZE_X: tl.constexpr,
BLOCK_SIZE_Y: tl.constexpr,
):
# Setup the pids.
pid_x = tl.program_id(axis=0)
pid_y = tl.program_id(axis=1)

# Compute offsets and masks for qweight_ptr.
offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]

masks_y = offsets_y < num_rows
masks_x = offsets_x < num_cols

masks = masks_y[:, None] & masks_x[None, :]

# Compute offsets and masks for result output ptr.
result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8)
result_offsets = (
8 * num_cols * result_offsets_y[:, None] + result_offsets_x[None, :]
)

result_masks_y = result_offsets_y < num_rows
result_masks_x = result_offsets_x < num_cols * 8
result_masks = result_masks_y[:, None] & result_masks_x[None, :]

# Load the weights.
iweights = tl.load(qweight_ptr + offsets, masks, 0.0)
iweights = tl.interleave(iweights, iweights)
iweights = tl.interleave(iweights, iweights)
iweights = tl.interleave(iweights, iweights)
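# Each int32 packs eight 4-bit weights. Interleaving a tensor with itself
# doubles its last dimension, so three interleaves leave 8 adjacent copies of
# every packed word; the per-lane shifts built below then pull a different
# 4-bit nibble out of each copy.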

# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
# that will map given indices to the correct order.
reverse_awq_order_tensor = (
(tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None]
).reshape(8)

# Use this to compute a set of shifts that can be used to unpack and
# reorder the values in iweights and zeros.
shifts = reverse_awq_order_tensor * 4
shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8))
shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

# Unpack and reorder: shift out the correct 4-bit value and mask.
iweights = (iweights >> shifts) & 0xF

# Compute zero offsets and masks.
zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]

zero_masks_y = zero_offsets_y < num_rows // group_size
zero_masks_x = zero_offsets_x < num_cols
zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]

# Load the zeros.
zeros = tl.load(zeros_ptr + zero_offsets, zero_masks, 0.0)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

# Unpack and reorder: shift out the correct 4-bit value and mask.
zeros = (zeros >> shifts) & 0xF

# Compute scale offsets and masks.
scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
scale_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8)
scale_offsets = num_cols * 8 * scale_offsets_y[:, None] + scale_offsets_x[None, :]
scale_masks_y = scale_offsets_y < num_rows // group_size
scale_masks_x = scale_offsets_x < num_cols * 8
scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]

# Load the scales.
scales = tl.load(scales_ptr + scale_offsets, scale_masks, 0.0)
scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

# Dequantize.
iweights = (iweights - zeros) * scales
iweights = iweights.to(result_ptr.type.element_ty)

# Finally, store.
tl.store(result_ptr + result_offsets, iweights, result_masks)


@triton.jit
def awq_gemm_kernel(
a_ptr,
b_ptr,
c_ptr,
zeros_ptr,
scales_ptr,
M,
N,
K,
group_size,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
SPLIT_K: tl.constexpr,
):
pid = tl.program_id(axis=0)
pid_z = tl.program_id(1)

# NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
# num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

pid_m = pid // num_pid_n
pid_n = pid % num_pid_n

accumulator_dtype = c_ptr.type.element_ty

# NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
# accumulator = tl.arange(0, BLOCK_SIZE_N)
# accumulator = tl.broadcast_to(accumulator[None, :],
# (BLOCK_SIZE_M, BLOCK_SIZE_N))
# accumulator = accumulator & 0x0
# accumulator = accumulator.to(accumulator_dtype)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)

# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
# that will map given indices to the correct order.
reverse_awq_order_tensor = (
(tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None]
).reshape(8)

# Create the necessary shifts to use to unpack.
shifts = reverse_awq_order_tensor * 4
shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8))
shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N))

# Offsets and masks.
offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
masks_am = offsets_am < M

offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
masks_bn = offsets_bn < N // 8

offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
masks_zn = offsets_zn < N // 8

offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
masks_sn = offsets_sn < N

offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offsets_a = K * offsets_am[:, None] + offsets_k[None, :]
offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :]

a_ptrs = a_ptr + offsets_a
b_ptrs = b_ptr + offsets_b

# NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv
# block_offset = BLOCK_SIZE_K * SPLIT_K
# for k in range(0, (K + block_offset - 1) // (block_offset)):
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
masks_k = offsets_k < K
masks_a = masks_am[:, None] & masks_k[None, :]
a = tl.load(a_ptrs, mask=masks_a, other=0.0)

masks_b = masks_k[:, None] & masks_bn[None, :]
b = tl.load(b_ptrs, mask=masks_b, other=0.0)
b = tl.interleave(b, b)
b = tl.interleave(b, b)
b = tl.interleave(b, b)
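# As in the dequantize kernel, the three interleaves replicate each packed
# int32 of b across 8 lanes so the shifts below can unpack one nibble per lane.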

# Dequantize b.
offsets_szk = (
BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K
) // group_size + tl.arange(0, 1)
offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
masks_zk = offsets_szk < K // group_size
masks_z = masks_zk[:, None] & masks_zn[None, :]
zeros_ptrs = zeros_ptr + offsets_z
zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N))

offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
masks_sk = offsets_szk < K // group_size
masks_s = masks_sk[:, None] & masks_sn[None, :]
scales_ptrs = scales_ptr + offsets_s
scales = tl.load(scales_ptrs, mask=masks_s, other=0.0)
scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))

b = (b >> shifts) & 0xF
zeros = (zeros >> shifts) & 0xF
b = (b - zeros) * scales
b = b.to(c_ptr.type.element_ty)

# Accumulate results.
accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)

offsets_k += BLOCK_SIZE_K * SPLIT_K
a_ptrs += BLOCK_SIZE_K * SPLIT_K
b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8)

c = accumulator.to(c_ptr.type.element_ty)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)


# qweights - [K , M // 8], int32
# scales - [K // G, M ], float16
# zeros - [K // G, M // 8], int32
def awq_dequantize_triton(
qweight: torch.Tensor,
scales: torch.Tensor,
zeros: torch.Tensor,
block_size_x: int = 32,
block_size_y: int = 32,
) -> torch.Tensor:
K = qweight.shape[0]
M = scales.shape[1]
group_size = qweight.shape[0] // scales.shape[0]

assert K > 0 and M > 0
assert scales.shape[0] == K // group_size and scales.shape[1] == M
assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8
assert group_size <= K
assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K

# Result tensor:
# number of rows = same as input tensor
# number of cols = 8 x input tensor num cols
result = torch.empty(
qweight.shape[0],
qweight.shape[1] * 8,
device=qweight.device,
dtype=scales.dtype,
)

Y = qweight.shape[0] # num rows
X = qweight.shape[1] # num cols

grid = lambda META: (
triton.cdiv(X, META["BLOCK_SIZE_X"]),
triton.cdiv(Y, META["BLOCK_SIZE_Y"]),
)
awq_dequantize_kernel[grid](
qweight,
scales,
zeros,
group_size,
result,
X,
Y,
BLOCK_SIZE_X=block_size_x,
BLOCK_SIZE_Y=block_size_y,
)

return result


# input - [M, K]
# qweight - [K, N // 8]
# qzeros - [K // G, N // 8]
# scales - [K // G, N]
# split_k_iters - parallelism along K-dimension, int, power of 2.
def awq_gemm_triton(
input: torch.Tensor,
qweight: torch.Tensor,
scales: torch.Tensor,
qzeros: torch.Tensor,
split_k_iters: int,
block_size_m: int = 32,
block_size_n: int = 32,
block_size_k: int = 32,
) -> torch.Tensor:
M, K = input.shape
N = qweight.shape[1] * 8
group_size = qweight.shape[0] // qzeros.shape[0]

assert N > 0 and K > 0 and M > 0
assert qweight.shape[0] == K and qweight.shape[1] == N // 8
assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8
assert scales.shape[0] == K // group_size and scales.shape[1] == N
assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0
assert split_k_iters <= 32
assert group_size <= K
assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K

grid = lambda META: (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
split_k_iters,
)

result = torch.zeros((split_k_iters, M, N), dtype=scales.dtype, device=input.device)
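# Split-K: for each z in range(split_k_iters), the programs with pid_z == z
# accumulate a partial M x N product over their slice of K into result[z];
# the partials are reduced by the sum(0) below.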

# A = input, B = qweight, C = result
# A = M x K, B = K x N, C = M x N
awq_gemm_kernel[grid](
input,
qweight,
result,
qzeros,
scales,
M,
N,
K,
group_size,
BLOCK_SIZE_M=block_size_m,
BLOCK_SIZE_N=block_size_n,
BLOCK_SIZE_K=block_size_k,
SPLIT_K=split_k_iters,
)

result = result.sum(0)

return result
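As a side note (not part of the PR), the packing scheme both kernels undo can be expressed compactly in pure PyTorch: every int32 holds eight 4-bit values stored in AWQ's [0, 4, 1, 5, 2, 6, 3, 7] order, so shifting by 4 times those offsets and masking with 0xF recovers the columns in logical order. A CPU reference like the sketch below is handy for sanity-checking `awq_dequantize_triton` on small shapes; the helper names (`AWQ_REVERSE_ORDER`, `unpack_awq_int32`, `awq_dequantize_ref`) are mine.

```python
import torch

AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]  # nibble position for each output column

def unpack_awq_int32(packed: torch.Tensor) -> torch.Tensor:
    """Unpack [rows, cols] int32 into [rows, cols * 8] nibbles in logical order."""
    shifts = torch.tensor([4 * o for o in AWQ_REVERSE_ORDER], dtype=torch.int32)
    nibbles = (packed.unsqueeze(-1) >> shifts) & 0xF   # [rows, cols, 8]
    return nibbles.reshape(packed.shape[0], -1)

def awq_dequantize_ref(qweight, scales, qzeros):
    """CPU reference for the layout above: qweight [K, N//8] int32,
    scales [K//G, N], qzeros [K//G, N//8] int32 -> dequantized [K, N]."""
    group_size = qweight.shape[0] // scales.shape[0]
    w = unpack_awq_int32(qweight).to(scales.dtype)      # [K, N]
    z = unpack_awq_int32(qzeros).to(scales.dtype)       # [K//G, N]
    return (w - z.repeat_interleave(group_size, dim=0)) * scales.repeat_interleave(
        group_size, dim=0
    )
```

On a GPU build this reference can be compared against `awq_dequantize_triton` on the same random tensors; the two should agree up to fp16 rounding.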
6 changes: 5 additions & 1 deletion python/sglang/srt/models/deepseek_v2.py
@@ -127,6 +127,10 @@
)
elif _is_cpu and _is_cpu_amx_available:
pass
elif _is_hip:
from sglang.srt.layers.quantization.awq_triton import (
awq_dequantize_triton as awq_dequantize,
)
else:
from vllm._custom_ops import awq_dequantize

@@ -2176,7 +2180,7 @@ def post_load_weights(self, is_nextn=False, weight_names=None):
)
if hasattr(self_attn.kv_b_proj, "qweight"):
# AWQ compatible
if _is_cuda:
if _is_cuda or _is_hip:
w = awq_dequantize(
self_attn.kv_b_proj.qweight,
self_attn.kv_b_proj.scales,
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -147,6 +147,7 @@ class TestFile:
# TestFile("test_vision_chunked_prefill.py", 175), # Disabled temporarily and track in #7701
TestFile("test_reasoning_parser.py", 5),
TestFile("test_rope_rocm.py", 3),
TestFile("test_awq_dequant.py", 2),
],
"per-commit-npu": [
TestFile("test_ascend_attention_backend.py", 400),
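The newly registered `test_awq_dequant.py` is not shown in this diff. A minimal sketch of the kind of consistency check such a test can perform (my illustration, not the actual test, assuming a GPU with Triton; `check_awq_gemm_consistency` is a hypothetical helper) is to compare `awq_gemm_triton` against dequantize-plus-matmul:

```python
import torch

from sglang.srt.layers.quantization.awq_triton import (
    awq_dequantize_triton,
    awq_gemm_triton,
)

def check_awq_gemm_consistency(M=16, K=256, N=512, group_size=128, split_k_iters=8):
    device = "cuda"
    x = torch.randn(M, K, dtype=torch.float16, device=device)
    qweight = torch.randint(-2**31, 2**31 - 1, (K, N // 8),
                            dtype=torch.int32, device=device)
    scales = torch.rand(K // group_size, N, dtype=torch.float16, device=device)
    qzeros = torch.randint(-2**31, 2**31 - 1, (K // group_size, N // 8),
                           dtype=torch.int32, device=device)

    ref = x @ awq_dequantize_triton(qweight, scales, qzeros)
    out = awq_gemm_triton(x, qweight, scales, qzeros, split_k_iters)
    # The GEMM kernel accumulates in fp16, so only loose agreement is expected.
    print("max abs diff:", (out - ref).abs().max().item())
```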