
Commit 8af9f58

Fix Local Attention off by 1 bug (#25927)
### Description

Previously, the local window size of the GQA op excluded the current token. This does not match standard HuggingFace implementations, where tokens are appended first and local masking is applied afterward; the mismatch can make the mask off by 1 during generation, leading to accuracy issues. This PR corrects the mismatch by including the current token in the window, which in practice effectively decreases the GQA window size by 1.

### Motivation and Context

This helps align our models with HuggingFace models.

---------

Co-authored-by: Kunal Vaishnavi <[email protected]>
1 parent 978bfca commit 8af9f58

9 files changed: +13 −16 lines changed
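A minimal sketch of the semantics change described in the commit message, mirroring the quantities used by the CPU kernel below (`seq_causal_length` counts the current token; positions are 0-based). This is illustrative Python, not ORT code:

```python
# Illustrative only: which key positions stay unmasked for one query token,
# before and after this commit. seq_causal_length counts the current token.

def visible_old(seq_causal_length: int, local_window_size: int) -> range:
    # Old behavior: local_window_size excluded the current token, so up to
    # local_window_size + 1 positions stayed visible.
    if local_window_size < 0 or seq_causal_length <= local_window_size + 1:
        return range(seq_causal_length)
    return range(seq_causal_length - local_window_size - 1, seq_causal_length)

def visible_new(seq_causal_length: int, local_window_size: int) -> range:
    # New behavior: local_window_size includes the current token, matching the
    # "append token, then mask" order of HuggingFace-style implementations.
    if local_window_size < 0 or seq_causal_length <= local_window_size:
        return range(seq_causal_length)
    return range(seq_causal_length - local_window_size, seq_causal_length)

# 8 tokens seen so far (current token at position 7), window size 4:
print(list(visible_old(8, 4)))  # [3, 4, 5, 6, 7] -> 5 visible positions
print(list(visible_new(8, 4)))  # [4, 5, 6, 7]    -> 4 visible positions
```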

onnxruntime/contrib_ops/cpu/bert/attention_parameters.h

Lines changed: 2 additions & 2 deletions

@@ -86,7 +86,7 @@ struct GroupQueryAttentionParameters : AttentionParameters {
   int kv_hidden_size;           // hidden size of key or value
   int seqlen_past_kv_cache;     // sequence length of past kv tensor
   int seqlen_present_kv_cache;  // sequence length of present kv tensor
-  int local_window_size;        // The window size excludes current token. It only includes tokens on the left side.
+  int local_window_size;        // Mask out tokens prior to total_sequence_length - local_window_size
   bool kv_share_buffer;
   bool is_subsequent_prompt;    // indicates whether we have past context and seqlen > 1
   bool is_first_prompt;         // indicates whether this is first decoding step

@@ -106,7 +106,7 @@ struct PagedAttentionParameters : AttentionParameters {
   int block_size;               // block size for kv cache
   int max_num_blocks_per_seq;   // max number of blocks per sequence for kv cache
   int num_blocks;               // number of blocks in kv cache
-  int local_window_size;        // The window size excludes current token. It only includes tokens on the left side.
+  int local_window_size;        // The window size includes new token. It only includes tokens on the left side.
   bool rotary_interleaved;
   float softcap;
 };

onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h

Lines changed: 4 additions & 5 deletions

@@ -297,16 +297,15 @@ class GQAAttentionBase {
       for (size_t seq = 0; seq < sequence_length; seq++) {
         size_t seq_causal_length = past_seqlen + seq + 1;

-        // local_window_size does not include the current query token, while window_size includes it.
         const bool should_apply_local_window = local_window_size_ >= 0 &&
-                                               seq_causal_length > static_cast<size_t>(local_window_size_) + 1;
+                                               seq_causal_length > static_cast<size_t>(local_window_size_);

-        const size_t start_offset = should_apply_local_window ? seq_causal_length - local_window_size_ - 1 : 0;
-        const size_t window_size = should_apply_local_window ? local_window_size_ + 1 : seq_causal_length;
+        const size_t start_offset = should_apply_local_window ? seq_causal_length - local_window_size_ : 0;
+        const size_t window_size = should_apply_local_window ? local_window_size_ : seq_causal_length;

         // Mask everything before local window, if local window should be applied
         if (should_apply_local_window) {
-          for (size_t total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) {
+          for (size_t total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_; total_seq_id++) {
             if constexpr (std::is_same<U, float>::value) {
               output_softmax[total_seq_id] = 0.f;
             } else {

onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h

Lines changed: 1 addition & 3 deletions

@@ -223,9 +223,7 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
     }

     p.use_smooth_softmax = params.use_smooth_softmax;
-
-    // local_windows_size in GQA does not include current query token, while windows_size in this kernel includes it.
-    p.window_size = params.local_window_size + 1;
+    p.window_size = params.local_window_size;
   }

   auto kernel_fn = attention_kernel_batched_impl<Attention>;

onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu

Lines changed: 1 addition & 1 deletion

@@ -476,7 +476,7 @@ Status FlashAttention(
       parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim,
       scale, parameters.softcap, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,
       reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
-      parameters.local_window_size, parameters.rotary_interleaved, parameters.is_packed_qkv));
+      parameters.local_window_size - 1, parameters.rotary_interleaved, parameters.is_packed_qkv));

   // if (parameters.left_padding && parameters.is_first_prompt) {
   //   ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock));

onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu

Lines changed: 1 addition & 1 deletion

@@ -326,7 +326,7 @@ Status FlashAttention(
   ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_varlen_fwd(
       device_prop, stream, q, key_cache, value_cache, output, cumulative_seqlens_q, cumulative_seqlens_kv,
       /*seqused_k*/ nullptr, block_table, softmax_lse, batch_size, num_heads, kv_num_heads, head_size, max_query_len,
-      max_seq_len, token_count, scale, softcap, /*is_causal*/ true, is_bf16, local_window_size, max_num_blocks_per_seq,
+      max_seq_len, token_count, scale, softcap, /*is_causal*/ true, is_bf16, local_window_size - 1, max_num_blocks_per_seq,
       block_size));

   DUMP_TENSOR_INIT();
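Both FlashAttention call sites above (GroupQueryAttention and PagedAttention) now pass `local_window_size - 1`. A hedged reading of why: the flash kernel's left window appears to count only keys strictly to the left of the query, while the op's `local_window_size` now counts the current token as well, so one is subtracted at that boundary. A toy sketch of that assumed conversion (not ORT code; the real call sites simply subtract 1 inline):

```python
# Hypothetical helper for illustration only; assumes a negative value means
# "no local window" on both sides of the boundary.
def to_flash_left_window(gqa_window: int) -> int:
    if gqa_window < 0:       # sliding window disabled in the GQA op
        return gqa_window    # assumed to remain "unlimited" for the flash kernel
    return gqa_window - 1    # drop the current token from the count

assert to_flash_left_window(4) == 3    # 4 keys incl. current -> 3 strictly-left keys
assert to_flash_left_window(-1) == -1  # disabled stays disabled
```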

onnxruntime/contrib_ops/webgpu/bert/attention.cc

Lines changed: 1 addition & 1 deletion

@@ -250,7 +250,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
   if (has_sliding_window) {
     // Sliding window
     shader.MainFunctionBody()
-        << "let should_apply_local_window = uniforms.local_window_size >= 0 && seq_causal_length > uniforms.local_window_size + 1;\n"
+        << "let should_apply_local_window = uniforms.local_window_size >= 0 && seq_causal_length > uniforms.local_window_size;\n"
        << "let start_offset = select(0, seq_causal_length - uniforms.local_window_size, should_apply_local_window);\n"
        << "let effective_seq_length = select(seq_causal_length, uniforms.local_window_size, should_apply_local_window);\n";
   } else {

onnxruntime/test/python/transformers/test_gqa.py

Lines changed: 1 addition & 1 deletion

@@ -582,7 +582,7 @@ def construct_local_mask(seqlen_q, seqlen_k, window_size, query_padding_mask, ke
     sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
     return torch.logical_or(
         col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
-        col_idx < row_idx + sk - sq - window_size[0],
+        col_idx <= row_idx + sk - sq - window_size[0],
     )
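For reference, a self-contained sanity check of the updated mask condition, using a simplified variant of the `construct_local_mask` helper above (padding-mask arguments dropped; values are illustrative): with a left window of 4 and the `<=` comparison, exactly 4 keys per query row stay visible, current token included, matching the new GQA semantics.

```python
import torch

# Simplified reference mask (True = masked out); mirrors the updated comparison.
def local_mask(seqlen_q: int, seqlen_k: int, window_size: tuple) -> torch.Tensor:
    row_idx = torch.arange(seqlen_q).view(-1, 1)
    col_idx = torch.arange(seqlen_k)
    sk = torch.full_like(col_idx, seqlen_k)
    sq = seqlen_q
    return torch.logical_or(
        col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
        col_idx <= row_idx + sk - sq - window_size[0],
    )

mask = local_mask(seqlen_q=1, seqlen_k=8, window_size=(4, 0))
visible = (~mask).nonzero(as_tuple=True)[1].tolist()
print(visible)  # [4, 5, 6, 7] -> 4 visible keys, current token (position 7) included
```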

onnxruntime/test/python/transformers/test_gqa_cpu.py

Lines changed: 1 addition & 1 deletion

@@ -1122,7 +1122,7 @@ def construct_local_mask(
     sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
     return torch.logical_or(
         col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
-        col_idx < row_idx + sk - sq - window_size[0],
+        col_idx <= row_idx + sk - sq - window_size[0],
     )

onnxruntime/test/python/transformers/test_paged_attention_cuda.py

Lines changed: 1 addition & 1 deletion

@@ -331,7 +331,7 @@ def construct_local_mask(
     sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
     return torch.logical_or(
         col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
-        col_idx < row_idx + sk - sq - window_size[0],
+        col_idx <= row_idx + sk - sq - window_size[0],
     )
