File tree Expand file tree Collapse file tree 2 files changed +6
-6
lines changed
python/sglang/srt/layers/quantization Expand file tree Collapse file tree 2 files changed +6
-6
lines changed Original file line number Diff line number Diff line change @@ -160,8 +160,8 @@ def _per_token_group_quant_fp8_colmajor(
160
160
"""
161
161
# Map the program id to the row of X and Y it should compute.
162
162
g_id = tl .program_id (0 )
163
- y_ptr += g_id * group_size
164
- y_q_ptr += g_id * group_size
163
+ y_ptr += g_id . to ( tl . int64 ) * group_size
164
+ y_q_ptr += g_id . to ( tl . int64 ) * group_size
165
165
166
166
# Convert g_id, the flattened block coordinate, to 2D so we can index
167
167
# into the output y_scales matrix
Original file line number Diff line number Diff line change @@ -35,12 +35,12 @@ __global__ void per_token_group_quant_8bit_kernel(
35
35
const int scale_num_rows = 0 ,
36
36
const int scale_stride = 0 ) {
37
37
const int threads_per_group = 16 ;
38
- const int local_group_id = threadIdx .x / threads_per_group;
38
+ const int64_t local_group_id = threadIdx .x / threads_per_group;
39
39
const int lane_id = threadIdx .x % threads_per_group;
40
40
41
- const int block_group_id = blockIdx .x * groups_per_block;
42
- const int global_group_id = block_group_id + local_group_id;
43
- const int block_group_offset = global_group_id * group_size;
41
+ const int64_t block_group_id = blockIdx .x * groups_per_block;
42
+ const int64_t global_group_id = block_group_id + local_group_id;
43
+ const int64_t block_group_offset = global_group_id * group_size;
44
44
45
45
float local_absmax = eps;
46
46
You can’t perform that action at this time.
0 commit comments