Skip to content

Commit e6d65fb

Browse files
authored
vulkan: support arbitrary KV dimension in flash attention (#16160)
The "Clamp" spec constant is already set based on whether KV is a multiple of Bc, so use it to control whether bounds checking is performed. Add bounds checking to the scalar and coopmat1 paths. The coopmat2 path didn't need any changes (the K/V tensors are already optionally clamped, so nothing else had to change).
1 parent 8656f5d commit e6d65fb

File tree

3 files changed

+38
-9
lines changed

3 files changed

+38
-9
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ void main() {
117117

118118

119119
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
120+
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
121+
continue;
122+
}
120123
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
121124
#if BLOCK_SIZE > 1
122125
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
@@ -155,7 +158,11 @@ void main() {
155158
uint32_t c = (idx + tid) % Bc;
156159
uint32_t r = (idx + tid) / Bc;
157160
if (idx + tid < Bc * Br) {
158-
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
161+
if (!KV_bounds_check || j * Bc + c < KV) {
162+
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
163+
} else {
164+
masksh[c][r] = float(0);
165+
}
159166
}
160167
}
161168
barrier();
@@ -172,8 +179,11 @@ void main() {
172179

173180
float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
174181
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
175-
rowmaxf[r] = Sf[r][0];
182+
rowmaxf[r] = NEG_FLT_MAX_OVER_2;
176183
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
184+
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
185+
continue;
186+
}
177187
rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
178188
}
179189
Moldf[r] = Mf[r];
@@ -190,6 +200,9 @@ void main() {
190200
// Compute sum across row of P
191201
rowsumf[r] = 0.0;
192202
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
203+
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
204+
continue;
205+
}
193206
rowsumf[r] += Pf[r][c];
194207
}
195208

@@ -203,6 +216,9 @@ void main() {
203216
}
204217

205218
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
219+
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
220+
continue;
221+
}
206222
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
207223
#if BLOCK_SIZE > 1
208224
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ layout (constant_id = 6) const uint32_t D_split = 16;
1313
const uint32_t HSK_pad = (HSK + 15) & ~15;
1414
const uint32_t HSV_pad = (HSV + 15) & ~15;
1515

16+
const bool KV_bounds_check = Clamp != 0;
17+
1618
layout (push_constant) uniform parameter {
1719
uint32_t N;
1820
uint32_t KV;

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,17 @@ void main() {
152152
uint32_t d = (idx + tid) % (HSK / 4);
153153
uint32_t c = (idx + tid) / (HSK / 4);
154154
if (c < Bc && d < HSK / 4) {
155+
f16vec4 K_Tf = f16vec4(0);
156+
if (!KV_bounds_check || j * Bc + c < KV) {
155157
#if BLOCK_SIZE > 1
156-
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
157-
uint ib = coord / BLOCK_SIZE;
158-
uint iqs = (coord % BLOCK_SIZE);
159-
f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
158+
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
159+
uint ib = coord / BLOCK_SIZE;
160+
uint iqs = (coord % BLOCK_SIZE);
161+
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
160162
#else
161-
f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
163+
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
162164
#endif
165+
}
163166

164167
ksh[c * kshstride + d] = K_Tf;
165168
}
@@ -202,16 +205,21 @@ void main() {
202205
uint32_t c = (idx + tid) % Bc;
203206
uint32_t r = (idx + tid) / Bc;
204207
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
205-
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
208+
if (!KV_bounds_check || j * Bc + c < KV) {
209+
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
210+
}
206211
}
207212
}
208213
barrier();
209214
}
210215

211216
float eMf[rows_per_thread];
212217
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
213-
float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride];
218+
float rowmaxf = NEG_FLT_MAX_OVER_2;
214219
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
220+
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
221+
continue;
222+
}
215223
rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
216224
}
217225
float Moldf = Mf[r];
@@ -233,6 +241,9 @@ void main() {
233241
}
234242

235243
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
244+
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
245+
continue;
246+
}
236247
float Pf[rows_per_thread];
237248
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
238249
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);

0 commit comments

Comments
 (0)