Skip to content

Commit 4a0d191

Browse files
Fix DeepSeek-V3 bug in DP+EP mode with large batch size / sequence length (#6449)
1 parent 5748241 commit 4a0d191

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

python/sglang/srt/layers/quantization/fp8_kernel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ def _per_token_group_quant_fp8_colmajor(
160160
"""
161161
# Map the program id to the row of X and Y it should compute.
162162
g_id = tl.program_id(0)
163-
y_ptr += g_id * group_size
164-
y_q_ptr += g_id * group_size
163+
y_ptr += g_id.to(tl.int64) * group_size
164+
y_q_ptr += g_id.to(tl.int64) * group_size
165165

166166
# Convert g_id, the flattened block coordinate, to 2D so we can index
167167
# into the output y_scales matrix

sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@ __global__ void per_token_group_quant_8bit_kernel(
3535
const int scale_num_rows = 0,
3636
const int scale_stride = 0) {
3737
const int threads_per_group = 16;
38-
const int local_group_id = threadIdx.x / threads_per_group;
38+
const int64_t local_group_id = threadIdx.x / threads_per_group;
3939
const int lane_id = threadIdx.x % threads_per_group;
4040

41-
const int block_group_id = blockIdx.x * groups_per_block;
42-
const int global_group_id = block_group_id + local_group_id;
43-
const int block_group_offset = global_group_id * group_size;
41+
const int64_t block_group_id = blockIdx.x * groups_per_block;
42+
const int64_t global_group_id = block_group_id + local_group_id;
43+
const int64_t block_group_offset = global_group_id * group_size;
4444

4545
float local_absmax = eps;
4646

0 commit comments

Comments
 (0)