
Commit cd5f91f

[CPU] GQA supports head_sink input for smooth softmax (microsoft#25269)
### Description

This is an extension of the [Smooth Softmax](microsoft#21867) feature. The difference is that each head has a learnable smooth factor that is added to the denominator of softmax; the smooth factor acts like an extra element that joins the softmax. The smooth factor s is used in softmax as follows:

```math
softmax_{i} = \frac{\exp(x_{i})}{\exp(s) + \sum_{j} \exp(x_{j})}
```

The head_sink input is a float tensor whose length equals the number of attention heads. For the h-th head, `head_sink[h]` is used as the smooth factor s. When head_sink is not provided, the constant 0 is used instead.

Changes:
- [x] Update the operator spec to add an optional new input `head_sink`.
- [x] Implement the CPU (MLAS) kernel.
- [x] Update test_gqa_cpu.py to test it.

The CUDA kernel will be updated later in a separate PR.
1 parent b49fc62 commit cd5f91f
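As a quick numerical check of the formula above (values chosen only for illustration): with two logits x = (0, 0) and smooth factor s = 0, every exponential equals 1, so

```math
softmax_{1} = softmax_{2} = \frac{\exp(0)}{\exp(0) + \exp(0) + \exp(0)} = \frac{1}{3}
```

The probabilities sum to 2/3 rather than 1, which is the point of the smooth factor: the `exp(s)` term lets a head withhold probability mass from the visible tokens. The same logits under plain softmax would give 1/2 each.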

15 files changed, +219 -104 lines changed

docs/ContribOperators.md

Lines changed: 3 additions & 1 deletion
@@ -2555,7 +2555,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Softcap value for attention weights. Default value is 0.</dd>
 </dl>
 
-#### Inputs (7 - 11)
+#### Inputs (7 - 12)
 
 <dl>
 <dt><tt>query</tt> : T</dt>
@@ -2580,6 +2580,8 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>2D tensor with shape (batch_size, sequence_length). When processing the first prompt the kernel uses only the first element</dd>
 <dt><tt>attention_bias</tt> (optional) : T</dt>
 <dd>additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)</dd>
+<dt><tt>head_sink</tt> (optional) : T</dt>
+<dd>1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.</dd>
 </dl>
 
 #### Outputs

docs/OperatorKernels.md

Lines changed: 3 additions & 3 deletions
@@ -538,7 +538,7 @@ Do not modify directly.*
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
 |GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
 |MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
@@ -942,7 +942,7 @@ Do not modify directly.*
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
 |GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
 |GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -1420,7 +1420,7 @@ Do not modify directly.*
 |FusedMatMulActivation|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**M** = tensor(float), tensor(float16)<br/> **T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
 |MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)|
 |MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
 |MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|

onnxruntime/contrib_ops/cpu/bert/attention_helper.h

Lines changed: 3 additions & 3 deletions
@@ -17,13 +17,13 @@ namespace onnxruntime {
 namespace contrib {
 
 template <typename T>
-inline void ComputeSmoothSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) {
-  MlasComputeSoftmax(score, score, N, D, false, true, tp);
+inline void ComputeSmoothSoftmaxInplace(T* score, int D, float sink, ThreadPool* tp) {
+  MlasComputeSoftmax(score, score, 1, D, false, true, sink, tp);
 }
 
 template <typename T>
 inline void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) {
-  MlasComputeSoftmax(score, score, N, D, false, false, tp);
+  MlasComputeSoftmax(score, score, N, D, false, false, 0.0f, tp);
 }
 
 template <typename T>
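For orientation, here is a minimal usage sketch of the reworked helper, normalizing a block of probability rows that all belong to one head. The include path, function name, and buffer layout are illustrative assumptions, not code from this PR:

```cpp
// Sketch: call the per-row helper once per attention-probability row, passing the
// owning head's sink (0.0f when the head_sink input is absent).
#include <cstddef>
#include "contrib_ops/cpu/bert/attention_helper.h"  // assumed include path

void SmoothSoftmaxRows(float* probs,         // num_rows contiguous rows of length row_len
                       std::size_t num_rows,
                       int row_len,
                       float sink) {         // head_sink[h] for the owning head, or 0.0f
  for (std::size_t r = 0; r < num_rows; ++r) {
    // The helper now fixes N = 1 internally, so each call normalizes exactly one row
    // and forwards the per-head sink to MlasComputeSoftmax.
    onnxruntime::contrib::ComputeSmoothSoftmaxInplace(
        probs + r * static_cast<std::size_t>(row_len), row_len, sink, /*tp*/ nullptr);
  }
}
```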

onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h

Lines changed: 7 additions & 4 deletions
@@ -51,6 +51,7 @@ class GQAAttentionBase {
   Status ApplyAttention(const T* Q,                   // Q data with shape BxNxSxH
                         const T* K,                   // K data with shape BxN_kvxSxH
                         const T* V,                   // V data with shape BxN_kvxSxH
+                        const T* head_sink,           // Head sink for smooth softmax, nullptr if not used
                         const Tensor* attention_bias, // Attention bias to add to QxK'
                         const Tensor* past_key,       // past K input tensor (if not using past state)
                         const Tensor* past_value,     // past V input tensor (if not using past state)
@@ -97,7 +98,7 @@ class GQAAttentionBase {
     const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
 
     if (gqa_mlas_supported) {
-      ComputeAttentionProbs(static_cast<T*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), attention_bias_data,
+      ComputeAttentionProbs(static_cast<T*>(attention_probs), Q, k, head_sink, seqlens_k->Data<int32_t>(), attention_bias_data,
                             batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache,
                             head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt,
                             tp, allocator);
@@ -110,7 +111,7 @@ class GQAAttentionBase {
                              hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
                              is_prompt, tp, allocator);
     } else {
-      ComputeAttentionProbs(static_cast<float*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), attention_bias_data,
+      ComputeAttentionProbs(static_cast<float*>(attention_probs), Q, k, head_sink, seqlens_k->Data<int32_t>(), attention_bias_data,
                             batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache,
                             head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt,
                             tp, allocator);
@@ -136,6 +137,7 @@ class GQAAttentionBase {
   void ComputeAttentionProbs(U* attention_probs,       // output buffer with size BxNxSxT
                              const T* Q,               // Q data. Its size is BxNxSxH
                              const T* K,               // k data. Its size is BxNxLxH
+                             const T* head_sink,       // for smooth softmax. Its size is N.
                              const int32_t* seqlens_k, // total - 1 sequence lengths tensor
                              const T* attention_bias,  // optional attention bias
                              const size_t batch_size,  // batch size of self-attention
@@ -310,8 +312,9 @@ class GQAAttentionBase {
           }
         }
 
-        if (use_smooth_softmax_) {
-          ComputeSmoothSoftmaxInplace(output_softmax + start_offset, 1, static_cast<int>(window_size), nullptr);
+        if (use_smooth_softmax_ || head_sink != nullptr) {
+          float sink = (head_sink != nullptr) ? static_cast<float>(head_sink[head_index]) : 0.0f;
+          ComputeSmoothSoftmaxInplace(output_softmax + start_offset, static_cast<int>(window_size), sink, nullptr);
         } else {
          ComputeAttentionSoftmaxInplace(output_softmax + start_offset, 1, static_cast<int>(window_size), nullptr);
         }
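The hunk above reads `head_sink[head_index]`; `head_index` comes from the surrounding loop over (batch, head) work items, which is not shown in this diff. A hedged sketch of that selection, assuming the usual flattened `batch * num_heads + head` indexing of the BxNxSxT probability buffer:

```cpp
// Sketch only: pick the smooth factor for one (batch, head) work item. The flattened
// index layout is an assumption; the real kernel's loop structure is more involved.
#include <cstddef>

float SelectSink(const float* head_sink,  // optional head_sink data (length num_heads), may be null
                 std::size_t i,           // flattened (batch, head) work-item index
                 std::size_t num_heads) {
  const std::size_t head_index = i % num_heads;
  // Learnable per-head factor when head_sink is provided; otherwise the constant 0
  // used by the original smooth-softmax path (use_smooth_softmax_).
  return head_sink != nullptr ? head_sink[head_index] : 0.0f;
}
```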

onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc

Lines changed: 3 additions & 1 deletion
@@ -206,9 +206,11 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
 
   ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
 
+  const T* head_sink_data = (head_sink != nullptr) ? head_sink->Data<T>() : nullptr;
+
   // Compute the attention score and apply the score to V
   return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
-                        attention_bias, past_key, past_value, output, present_k, present_v,
+                        head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v,
                         seqlens_k, parameters, allocator, context);
 }
 }  // namespace contrib
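The `head_sink` tensor itself is obtained earlier in `Compute()`, outside this hunk. A hedged sketch of how the optional input would typically be fetched in an ORT CPU kernel; the index 11 matches the schema registration below, and the shape check is an illustrative assumption rather than the PR's exact validation:

```cpp
// Sketch (lives inside GroupQueryAttention<T>::Compute, not a standalone function):
const Tensor* head_sink = context->Input<Tensor>(11);  // nullptr when the input is omitted
if (head_sink != nullptr) {
  // Illustrative check: a 1D tensor with one smooth factor per attention head.
  ORT_RETURN_IF_NOT(head_sink->Shape().NumDimensions() == 1 &&
                        head_sink->Shape()[0] == static_cast<int64_t>(num_heads_),
                    "head_sink must be a 1D tensor with length num_heads");
}
```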

onnxruntime/core/graph/contrib_ops/bert_defs.cc

Lines changed: 5 additions & 0 deletions
@@ -1184,6 +1184,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
                "additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)",
                "T",
                OpSchema::Optional)
+        .Input(11,
+               "head_sink",
+               "1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.",
+               "T",
+               OpSchema::Optional)
         .Output(0,
                 "output",
                 "3D output tensor with shape (batch_size, sequence_length, hidden_size)",

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 1 addition & 0 deletions
@@ -1020,6 +1020,7 @@ MlasComputeSoftmax(
     size_t D,
     bool LogSoftmax,
     bool SmoothSoftmax,
+    float Sink,
     MLAS_THREADPOOL* ThreadPool
     );
 
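A minimal sketch of the extended entry point in use; existing call sites simply append the new argument (0.0f when no smoothing is wanted). The wrapper name and include path here are assumptions:

```cpp
// Sketch: smooth softmax over a single row of logits via the extended MLAS entry point.
#include <cstddef>
#include "mlas.h"  // assumed include path for the MLAS public header

void SoftmaxOneRow(const float* logits, float* probs, std::size_t length, float sink) {
  // N = 1 row, D = length, LogSoftmax = false, SmoothSoftmax = true,
  // Sink = per-head smooth factor (0.0f reproduces the previous smooth-softmax behavior).
  MlasComputeSoftmax(logits, probs, /*N*/ 1, /*D*/ length, /*LogSoftmax*/ false,
                     /*SmoothSoftmax*/ true, sink, /*ThreadPool*/ nullptr);
}
```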

onnxruntime/core/mlas/lib/compute.cpp

Lines changed: 13 additions & 4 deletions
@@ -74,6 +74,7 @@ struct MLAS_SOFTMAX_WORK_BLOCK {
     ptrdiff_t ThreadCountN;
     bool LogSoftmax;
     bool SmoothSoftmax;
+    float Sink;
     const T* Input;
     T* Output;
     size_t N;
@@ -850,6 +851,7 @@ Return Value:
     const size_t D = WorkBlock->D;
     const bool LogSoftmax = WorkBlock->LogSoftmax;
     const bool SmoothSoftmax = WorkBlock->SmoothSoftmax;
+    const float Sink = WorkBlock->Sink;
 
     const float* Input = WorkBlock->Input + n * D;
     float* Output = WorkBlock->Output + n * D;
@@ -880,11 +882,12 @@ Return Value:
 #else
         float Maximum = MlasReduceMaximumF32Kernel(Input, D);
 #endif
-        float NegativeMaximum = -Maximum;
-        if (SmoothSoftmax && NegativeMaximum > 0.0f) {
-            NegativeMaximum = 0.0f;
+        if (SmoothSoftmax && Sink > Maximum) {
+            Maximum = Sink;
         }
 
+        float NegativeMaximum = -Maximum;
+
         //
         // Compute the exponential function for each element of the row (save to Temp if provided) and
         // compute the sum of these exponential functions.
@@ -897,7 +900,7 @@ Return Value:
 #endif
 
         if (SmoothSoftmax) {
-            Accumulation += expf(NegativeMaximum);
+            Accumulation += expf(Sink + NegativeMaximum);
         }
 
         if (LogSoftmax) {
@@ -1014,6 +1017,7 @@ MlasComputeSoftmax(
     size_t D,
     bool LogSoftmax,
     bool SmoothSoftmax,
+    float Sink,
     MLAS_THREADPOOL* ThreadPool
     )
 /*++
@@ -1039,6 +1043,8 @@ Routine Description:
 
     SmoothSoftmax - Supplies true if a smooth factor is used in softmax operation.
 
+    Sink - Supplies the smooth factor to use in the softmax operation.
+
    ThreadPool - Supplies the thread pool object to use, else nullptr if the
        base library threading support should be used.
 
@@ -1060,6 +1066,7 @@ Return Value:
     WorkBlock.Output = Output;
     WorkBlock.N = N;
     WorkBlock.D = D;
+    WorkBlock.Sink = Sink;
 
     //
     // Compute the number of target threads given the complexity of the softmax
@@ -1097,6 +1104,7 @@ MlasComputeSoftmax<float>(
     size_t D,
     bool LogSoftmax,
     bool SmoothSoftmax,
+    float Sink,
     MLAS_THREADPOOL* ThreadPool
     );
 
@@ -1110,6 +1118,7 @@ MlasComputeSoftmax<MLAS_FP16>(
     size_t D,
     bool LogSoftmax,
     bool SmoothSoftmax,
+    float Sink,
     MLAS_THREADPOOL* ThreadPool
     );
 
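Putting the kernel changes together, here is a scalar reference of the per-row computation after this commit. The real MLAS routine is vectorized and may stage results in a Temp buffer; this sketch only mirrors the arithmetic (and assumes d > 0):

```cpp
// Scalar sketch of smooth softmax with a sink: the running maximum now also covers the
// sink (so exp never overflows), and exp(sink - max) joins the denominator.
#include <algorithm>
#include <cmath>
#include <cstddef>

void SmoothSoftmaxRowReference(const float* input, float* output, std::size_t d, float sink) {
  float maximum = *std::max_element(input, input + d);  // assumes d > 0
  maximum = std::max(maximum, sink);
  const float negative_maximum = -maximum;

  float accumulation = 0.0f;
  for (std::size_t i = 0; i < d; ++i) {
    output[i] = std::exp(input[i] + negative_maximum);
    accumulation += output[i];
  }
  accumulation += std::exp(sink + negative_maximum);  // the sink's share of the denominator

  // softmax_i = exp(x_i) / (exp(s) + sum_j exp(x_j))
  const float scale = 1.0f / accumulation;
  for (std::size_t i = 0; i < d; ++i) {
    output[i] *= scale;
  }
}
```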

onnxruntime/core/providers/cpu/math/softmax_shared.cc

Lines changed: 1 addition & 1 deletion
@@ -99,7 +99,7 @@ common::Status SoftmaxCPU<float>(size_t N,
                                  float* Ydata,
                                  bool logarithmic,
                                  onnxruntime::concurrency::ThreadPool* thread_pool) {
-  MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, thread_pool);
+  MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, 0.0f, thread_pool);
   return Status::OK();
 }
 

onnxruntime/core/providers/cpu/ml/ml_common.h

Lines changed: 1 addition & 1 deletion
@@ -445,7 +445,7 @@ void batched_update_scores_inplace(gsl::span<T> scores, int64_t num_batches_in,
   }
 
   if (use_mlas) {
-    MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow<size_t>(batch_size), false, false, threadpool);
+    MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow<size_t>(batch_size), false, false, 0.0f, threadpool);
   } else {
     while (s < s_end) {
       gsl::span<float> scores_for_batch(s, s + batch_size);
