@@ -146,7 +146,8 @@ Status GatedRelativePositionBias<T>::ComputeInternal(OpKernelContext* context) c
   const auto BNS = batch_size * num_heads_ * seq_len;
   const size_t elements_in_query = (size_t)BNS * (size_t)head_size;
   const size_t elements_after_gemm = (size_t)BNS * (size_t)D;
-  size_t workspace_size = sizeof(T) * (elements_in_query + (seq_len < D) ? elements_after_gemm : (size_t)0);
+  bool reuse_output = (seq_len >= D);
+  size_t workspace_size = sizeof(T) * (elements_in_query + (reuse_output ? (size_t)0 : elements_after_gemm));
   auto workspace = GetScratchBuffer<void>(workspace_size, context->GetComputeStream());
 
   // format 1: BxSx(NH * total_matrix) => matrix_to_transpose * (BxNxSxH)
@@ -161,9 +162,9 @@ Status GatedRelativePositionBias<T>::ComputeInternal(OpKernelContext* context) c
                                  false, head_size, reinterpret_cast<CudaT*>(static_cast<CudaT*>(nullptr)), total_maxtrix);
 
   // reuse output if possible
-  CudaT* gemm_output = (seq_len < D) ? (reinterpret_cast<CudaT*>(workspace.get()) + elements_in_query)
-                                     : reinterpret_cast<CudaT*>(output->template MutableData<T>());
-  int ld_gemm_output = max(seq_len, D);
+  CudaT* gemm_output = reuse_output ? reinterpret_cast<CudaT*>(output->template MutableData<T>())
+                                    : (reinterpret_cast<CudaT*>(workspace.get()) + elements_in_query);
+  int ld_gemm_output = reuse_output ? seq_len : D;
 
   const CudaT one = ToCudaType<T>::FromFloat(1.0f);
   const CudaT zero = ToCudaType<T>::FromFloat(0.0f);
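The removed line had an operator-precedence bug: `?:` binds more loosely than `+`, so `elements_in_query + (seq_len < D) ? elements_after_gemm : (size_t)0` parses as `(elements_in_query + (seq_len < D)) ? elements_after_gemm : (size_t)0`. The condition is always nonzero, so the workspace never reserves room for the transposed query. Note also that the new `ld_gemm_output = reuse_output ? seq_len : D` is equivalent to the old `max(seq_len, D)` given `reuse_output = (seq_len >= D)`; it just ties the leading dimension explicitly to the buffer choice. A minimal standalone sketch of the precedence issue (the sizes below are made-up placeholders, not values from the kernel):

// Sketch of the precedence bug fixed above; sizes are arbitrary placeholders.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t elements_in_query = 1024;
  const size_t elements_after_gemm = 4096;
  const int seq_len = 8;
  const int D = 16;

  // Old grouping: ?: binds more loosely than +, so this parses as
  // ((elements_in_query + (seq_len < D)) ? elements_after_gemm : (size_t)0).
  // The sum is always nonzero, so the result is always elements_after_gemm
  // and elements_in_query never contributes to the allocation size.
  size_t buggy = sizeof(float) * (elements_in_query + (seq_len < D) ? elements_after_gemm : (size_t)0);

  // Fixed grouping, mirroring the + lines in the diff.
  bool reuse_output = (seq_len >= D);
  size_t fixed = sizeof(float) * (elements_in_query + (reuse_output ? (size_t)0 : elements_after_gemm));

  printf("buggy=%zu fixed=%zu\n", buggy, fixed);  // buggy=16384 fixed=20480
  return 0;
}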