Skip to content

Commit aa5b9ea

Browse files
zhanghuanrongPatriceVignola
authored andcommitted
fix error due to () not used on operator priority. (#14699)
1 parent 1af40db commit aa5b9ea

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ Status GatedRelativePositionBias<T>::ComputeInternal(OpKernelContext* context) c
146146
const auto BNS = batch_size * num_heads_ * seq_len;
147147
const size_t elements_in_query = (size_t)BNS * (size_t)head_size;
148148
const size_t elements_after_gemm = (size_t)BNS *(size_t)D;
149-
size_t workspace_size = sizeof(T) * (elements_in_query + (seq_len < D) ? elements_after_gemm : (size_t)0);
149+
bool reuse_output = (seq_len >= D);
150+
size_t workspace_size = sizeof(T) * (elements_in_query + (reuse_output ? (size_t)0 : elements_after_gemm));
150151
auto workspace = GetScratchBuffer<void>(workspace_size, context->GetComputeStream());
151152

152153
// format 1: BxSx(NH * total_matrix) => matrix_to_transpose * (BxNxSxH)
@@ -161,9 +162,9 @@ Status GatedRelativePositionBias<T>::ComputeInternal(OpKernelContext* context) c
161162
false, head_size, reinterpret_cast<CudaT*>(static_cast<CudaT*>(nullptr)), total_maxtrix);
162163

163164
// reuse output if possible
164-
CudaT* gemm_output = (seq_len < D) ? (reinterpret_cast<CudaT*>(workspace.get()) + elements_in_query)
165-
: reinterpret_cast<CudaT*>(output->template MutableData<T>());
166-
int ld_gemm_output = max(seq_len, D);
165+
CudaT* gemm_output = reuse_output ? reinterpret_cast<CudaT*>(output->template MutableData<T>())
166+
: (reinterpret_cast<CudaT*>(workspace.get()) + elements_in_query);
167+
int ld_gemm_output = reuse_output ? seq_len : D;
167168

168169
const CudaT one = ToCudaType<T>::FromFloat(1.0f);
169170
const CudaT zero = ToCudaType<T>::FromFloat(0.0f);

0 commit comments

Comments
 (0)