
Commit 1211c24

hariharans29 authored and fuhengwu2021 committed
Extend vocab padding for logits MatMul for fp16 GPT2 GreedySearch (microsoft#13842)
1 parent 525205d

File tree

7 files changed: +105 -6 lines

docs/ContribOperators.md

Lines changed: 2 additions & 0 deletions
@@ -1708,6 +1708,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 <dd>no repeat ngrams size</dd>
 <dt><tt>pad_token_id</tt> : int (required)</dt>
 <dd>The id of the padding token</dd>
+<dt><tt>vocab_size</tt> : int</dt>
+<dd>Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape</dd>
 </dl>

 #### Inputs (2 - 7)
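For illustration, a minimal sketch of how the new attribute might be set when assembling a GreedySearch node with the ONNX Python helpers; the empty decoder graph and the attribute values below are placeholder assumptions, not values taken from this commit:

```python
# Sketch only: attach the optional vocab_size attribute to a com.microsoft GreedySearch node.
# The trivial decoder graph and the token ids are hypothetical placeholders.
from onnx import helper

decoder_graph = helper.make_graph([], "decoder", [], [])  # stands in for the real decoder subgraph

greedy_search_node = helper.make_node(
    "GreedySearch",
    inputs=["input_ids", "max_length", "min_length", "repetition_penalty"],
    outputs=["sequences"],
    domain="com.microsoft",
    model_type=0,        # 0: decoder-only models such as GPT-2
    decoder=decoder_graph,
    eos_token_id=50256,  # placeholder
    pad_token_id=50256,  # placeholder
    vocab_size=1000,     # actual vocab size; the logits MatMul weight may be padded beyond this
)
```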

onnxruntime/contrib_ops/cpu/transformers/greedy_search_parameters.cc

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ void GreedySearchParameters::ParseFromAttributes(const OpKernelInfo& info) {
   pad_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("pad_token_id", -1));
   decoder_start_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("decoder_start_token_id", -1));
   no_repeat_ngram_size = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_repeat_ngram_size", 0));
+  vocab_size = static_cast<int>(info.GetAttrOrDefault<int64_t>("vocab_size", -1));
 }

 void GreedySearchParameters::ParseFromInputs(OpKernelContext* context) {

onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc

Lines changed: 14 additions & 4 deletions
@@ -244,7 +244,7 @@ Status ProcessLogits(const OrtValue& logits, //
   // NOTE: `padded_vocab_size` MAY be different from `vocab_size`.
   // But the following implementation should work correctly if they are the same
   // or different.
-  int padded_vocab_size = static_cast<int>(logits_shape[2]);
+  auto padded_vocab_size = static_cast<int>(logits_shape[2]);

   cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);

@@ -475,12 +475,17 @@ Status GreedySearchProcessLogits(
   typedef typename ToCudaType<T>::MappedType CudaT;
   const CudaT* logits_data = reinterpret_cast<const CudaT*>(logits.Get<Tensor>().Data<T>());

-  // Logits has shape (batch_size, input_length, vocab_size),
+  // Logits has shape (batch_size, input_length, padded_vocab_size),
   // where input_length equals to parameters_->sequence_length for first subgraph call, and 1 for the remaining calls.
   const TensorShape& logits_shape = logits.Get<Tensor>().Shape();
   ORT_ENFORCE(logits_shape.NumDimensions() == 3);
   auto input_length = logits_shape[1];

+  // NOTE: `padded_vocab_size` MAY be different from `vocab_size`.
+  // But the following implementation should work correctly if they are the same
+  // or different.
+  auto padded_vocab_size = static_cast<int>(logits_shape[2]);
+
   cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);

   // Get logits for the last token:
@@ -489,13 +494,18 @@ Status GreedySearchProcessLogits(
   gsl::span<T>& next_token_scores = greedy_state->next_token_scores;

   // TODO(tianleiwu): use one kernel to replace a loop of memory copy.
-  const CudaT* current_logits = logits_data + (input_length - 1) * vocab_size;
+  // Move the pointer in increments of padded_vocab_size to account for any padding
+  // in the logits weight of the MatMul.
+  const CudaT* current_logits = logits_data + (input_length - 1) * padded_vocab_size;
   for (int i = 0; i < batch_beam_size; i++) {
+    // We only copy what is relevant (i.e. vocab_size entries), as padded_vocab_size may include
+    // logits corresponding to the "padded" portion of the vocabulary, which we ignore
+    // for token generation.
     gsl::span<const T> source(reinterpret_cast<const T*>(current_logits), vocab_size);
     gsl::span<T> target = next_token_scores.subspan(i * vocab_size, vocab_size);
     CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), sizeof(T) * vocab_size,
                                          cudaMemcpyDeviceToDevice, cuda_stream));
-    current_logits += input_length * vocab_size;
+    current_logits += input_length * padded_vocab_size;
   }

 #ifdef DEBUG_GENERATION
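To make the indexing change above concrete, here is a small NumPy sketch of the same idea (shapes and values are illustrative, and this mirrors the pointer arithmetic rather than the CUDA code itself): the logits buffer is strided by padded_vocab_size, but only the first vocab_size scores of the last token are kept for token generation.

```python
# Illustration only: last-token logits extraction when the vocab dimension is padded.
import numpy as np

batch_beam_size, input_length = 2, 4
vocab_size, padded_vocab_size = 1000, 1600  # padded for tensor-core-friendly fp16 MatMul

# Logits produced by the decoder subgraph: the innermost dimension is the *padded* vocab.
logits = np.random.rand(batch_beam_size, input_length, padded_vocab_size).astype(np.float16)

# Advance by padded_vocab_size per token (the stride fix), then keep only the
# first vocab_size scores per sequence (the copy width), ignoring padded entries.
next_token_scores = logits[:, input_length - 1, :vocab_size]
assert next_token_scores.shape == (batch_beam_size, vocab_size)
```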

onnxruntime/core/graph/contrib_ops/contrib_defs.cc

Lines changed: 4 additions & 0 deletions
@@ -1087,6 +1087,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GreedySearch, 1,
       .Attr("model_type", "model type: 0 for decoder only like GPT-2; 1 for encoder decoder like Bart", AttributeProto::INT, static_cast<int64_t>(0))
       .Attr("encoder", "The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE)
       .Attr("decoder", "Decoder subgraph to execute in a loop.", AttributeProto::GRAPH)
+      .Attr("vocab_size",
+            "Size of the vocabulary. "
+            "If not provided, it will be inferred from the decoder subgraph's output shape",
+            AttributeProto::INT, static_cast<int64_t>(-1))
       .Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I")
       .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I")
       .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional)

onnxruntime/python/tools/transformers/convert_generation.py

Lines changed: 2 additions & 2 deletions
@@ -906,14 +906,14 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
     # We only want to pad the logits MatMul weight in the decoder for fp16 models.
     # The inherent assumption is that fp16 models run on GPU for which all
     # dims need to be a multiple of 8 to leverage tensor cores.
-    # NOTE: We currently only support padding the MatMul logits weight for GPT2 BeamSearch.
+    # NOTE: We currently only support padding the MatMul logits weight for GPT2 GreedySearch/BeamSearch.
     # This can be expanded to other models/decoding strategies later
     logits_matmul_weight_padded = False
     if (
         args.pad_vocab_size
         and args.precision == Precision.FLOAT16
         and is_gpt2
-        and generation_type == GenerationType.BEAMSEARCH
+        and (generation_type == GenerationType.BEAMSEARCH or generation_type == GenerationType.GREEDYSEARCH)
     ):
         logger.info(
             f"Pad logits MatMul weights for optimal MatMul perf in fp16 on {args.decoder_onnx}. "
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <memory>
+#include <vector>
+#include "gtest/gtest.h"
+#include "core/common/gsl.h"
+#include "core/session/onnxruntime_cxx_api.h"
+#include "test/common/cuda_op_test_utils.h"
+
+extern std::unique_ptr<Ort::Env> ort_env;
+
+namespace onnxruntime {
+namespace test {
+
+TEST(GreedySearchTest, GptGreedySearchFp16_VocabPadded) {
+  std::vector<int64_t> input_ids_shape{2, 4};
+  std::vector<int32_t> input_ids{
+      0, 0, 0, 52, 0, 0, 195, 731};
+
+  std::vector<int64_t> parameter_shape{1};
+  std::vector<int32_t> max_length{10};
+  std::vector<int32_t> min_length{1};
+  std::vector<float> repetition_penalty{1.0f};
+
+  std::vector<int64_t> expected_output_shape{input_ids_shape[0], max_length[0]};
+
+  std::vector<int32_t> expected_output{
+      0, 0, 0, 52, 204, 204, 204, 204, 204, 204,
+      0, 0, 195, 731, 731, 114, 114, 114, 114, 114};
+
+  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+  auto input_ids_tensor = Ort::Value::CreateTensor(
+      info, input_ids.data(), input_ids.size(), input_ids_shape.data(), input_ids_shape.size());
+
+  auto max_length_tensor = Ort::Value::CreateTensor(
+      info, max_length.data(), max_length.size(), parameter_shape.data(), parameter_shape.size());
+
+  auto min_length_tensor = Ort::Value::CreateTensor(
+      info, min_length.data(), min_length.size(), parameter_shape.data(), parameter_shape.size());
+
+  auto repetition_penalty_tensor = Ort::Value::CreateTensor(
+      info, repetition_penalty.data(), repetition_penalty.size(), parameter_shape.data(), parameter_shape.size());
+
+  std::vector<Ort::Value> ort_inputs;
+  ort_inputs.push_back(std::move(input_ids_tensor));
+  ort_inputs.push_back(std::move(max_length_tensor));
+  ort_inputs.push_back(std::move(min_length_tensor));
+  ort_inputs.push_back(std::move(repetition_penalty_tensor));
+  const char* input_names[] = {"input_ids", "max_length", "min_length", "repetition_penalty"};
+  const char* const output_names[] = {"sequences"};
+
+  constexpr int min_cuda_architecture = 530;
+  if (HasCudaEnvironment(min_cuda_architecture)) {
+    Ort::SessionOptions session_options;
+#ifdef USE_CUDA
+    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
+#endif
+
+    // The following model was obtained from testdata/transformers/tiny_gpt2_beamsearch_fp16.onnx by setting beam_size == 1
+    // and padding the vocabulary size from 1000 to 1600 (just for illustrative and testing purposes) to check that the
+    // greedy search implementation can handle such a scenario.
+    Ort::Session session(*ort_env, ORT_TSTR("testdata/transformers/tiny_gpt2_greedysearch_fp16_padded_vocab.onnx"), session_options);
+
+    auto ort_outputs = session.Run(Ort::RunOptions{}, input_names, ort_inputs.data(), ort_inputs.size(),
+                                   output_names, 1);
+
+    ASSERT_EQ(ort_outputs.size(), 1U);
+    const auto& sequences = ort_outputs[0];
+    ASSERT_TRUE(sequences.IsTensor());
+
+    auto result_ts = sequences.GetTensorTypeAndShapeInfo();
+    ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, result_ts.GetElementType());
+
+    ASSERT_EQ(expected_output_shape, result_ts.GetShape());
+    const auto* result_vals = sequences.GetTensorData<int32_t>();
+    auto result_span = gsl::make_span(result_vals, expected_output.size());
+    ASSERT_TRUE(std::equal(expected_output.cbegin(), expected_output.cend(), result_span.begin(), result_span.end()));
+  }
+}
+}  // namespace test
+}  // namespace onnxruntime
Binary file not shown (343 KB).
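For reference, roughly the same check as the gtest above can be run from Python against the padded-vocab test model (a sketch, assuming a CUDA-enabled onnxruntime build; the input values come from the C++ test):

```python
# Run the padded-vocab GPT-2 GreedySearch test model via the ONNX Runtime Python API.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    "testdata/transformers/tiny_gpt2_greedysearch_fp16_padded_vocab.onnx",
    providers=["CUDAExecutionProvider"],  # fp16 path assumes a CUDA device (SM >= 5.3)
)

inputs = {
    "input_ids": np.array([[0, 0, 0, 52], [0, 0, 195, 731]], dtype=np.int32),
    "max_length": np.array([10], dtype=np.int32),
    "min_length": np.array([1], dtype=np.int32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
}

sequences = session.run(["sequences"], inputs)[0]
print(sequences)  # expected to match the sequences asserted in the C++ test
```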
