Skip to content

Commit 02a0be3

Browse files
authored
Optimize Transpose around QLinearSoftmax (microsoft#22849)
### Description <!-- Describe your changes. --> - Improved the handling of Transpose nodes around QLinearSoftmax in the Level 3 NHWC Transformer. - Removed the redundant handlers HandleQLinearConcat and HandleQLinearBinaryOp. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> By merging and eliminating redundant Transpose nodes, the Image Segmentation i8 model (MobileNetv2 + DeepLabv3) achieves a 2.34X speedup.
1 parent 135d8b2 commit 02a0be3

File tree

4 files changed

+50
-12
lines changed

4 files changed

+50
-12
lines changed

onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,14 +1654,14 @@ static bool HandleSplit(HandlerArgs& args) {
16541654

16551655
constexpr HandlerInfo split_handler = {&FirstInput, &HandleSplit};
16561656

1657-
static bool HandleConcat(HandlerArgs& args) {
1657+
bool HandleConcat(HandlerArgs& args) {
16581658
return HandleSimpleNodeWithAxis(args);
16591659
}
16601660

16611661
constexpr HandlerInfo concat_handler = {&AllInputs, &HandleConcat};
16621662

16631663
// Handles Softmax, Hardmax, and LogSoftmax
1664-
static bool HandleSoftHardMax(HandlerArgs& args) {
1664+
bool HandleSoftHardMax(HandlerArgs& args) {
16651665
if (args.ctx.opset >= 13) {
16661666
return HandleSimpleNodeWithAxis(args, /*default_axis*/ -1);
16671667
}

onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ bool HandleSimpleNodeBroadcast(HandlerArgs& args);
7171
// Transposes all inputs and all outputs. Updates axis attribute.
7272
bool HandleSimpleNodeWithAxis(HandlerArgs& args, std::optional<int64_t> default_axis = std::nullopt);
7373

74+
bool HandleConcat(HandlerArgs& args);
75+
bool HandleSoftHardMax(HandlerArgs& args);
76+
7477
// base handlers that are used by extended handlers. add from transpose_optimizer.cc as needed.
7578
bool HandleReduceOps(HandlerArgs& args);
7679
bool HandleResize([[maybe_unused]] HandlerArgs& args);

onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,6 @@ static bool EPAwareHandleResize(HandlerArgs& args) {
3434

3535
constexpr HandlerInfo ep_aware_resize_handler = {&FirstInput, &EPAwareHandleResize};
3636

37-
static bool HandleQLinearConcat(HandlerArgs& args) {
38-
return HandleSimpleNodeWithAxis(args);
39-
}
40-
4137
std::vector<size_t> QLinearConcatInputs(OptimizerCtx& ctx, api::NodeRef& node) {
4238
(void)ctx;
4339
std::vector<size_t> indices;
@@ -48,19 +44,15 @@ std::vector<size_t> QLinearConcatInputs(OptimizerCtx& ctx, api::NodeRef& node) {
4844
return indices;
4945
}
5046

51-
constexpr HandlerInfo q_linear_concat_handler = {&QLinearConcatInputs, &HandleQLinearConcat};
52-
53-
static bool HandleQLinearBinaryOp(HandlerArgs& args) {
54-
return HandleSimpleNodeBroadcast(args);
55-
}
47+
constexpr HandlerInfo q_linear_concat_handler = {&QLinearConcatInputs, &HandleConcat};
5648

5749
std::vector<size_t> QLinearBinaryOpInputs(OptimizerCtx&, api::NodeRef&) {
5850
// Inputs are: [A, A_scale, A_zero_point, B, B_scale, B_zero_point, C_scale, C_zero_point],
5951
// we want [A, B].
6052
return {0, 3};
6153
}
6254

63-
constexpr HandlerInfo q_linear_binary_op_handler = {&QLinearBinaryOpInputs, &HandleQLinearBinaryOp};
55+
constexpr HandlerInfo q_linear_binary_op_handler = {&QLinearBinaryOpInputs, &HandleSimpleNodeBroadcast};
6456

6557
static bool HandleQLinearPoolOp(HandlerArgs& args) {
6658
// Swap between channel first/last variants. Only works for applicable values of perm.
@@ -129,6 +121,7 @@ constexpr HandlerInfo max_pool_op_handler = {&FirstInput, &HandleMaxPool};
129121

130122
constexpr HandlerInfo node_1_inp_handler = {&FirstInput, &HandleSimpleNode};
131123
constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOps};
124+
constexpr HandlerInfo soft_hard_max_handler = {&FirstInput, &HandleSoftHardMax};
132125
constexpr HandlerInfo contrib_quantize_dequantize_linear_handler = {&FirstInput,
133126
&HandleContribQuantizeDequantizeLinear};
134127

@@ -148,6 +141,7 @@ const HandlerMap& OrtExtendedHandlers() {
148141
{"com.microsoft.QLinearMul", q_linear_binary_op_handler},
149142
{"com.microsoft.QLinearReduceMean", reduce_op_handler},
150143
{"com.microsoft.QLinearSigmoid", node_1_inp_handler},
144+
{"com.microsoft.QLinearSoftmax", soft_hard_max_handler},
151145
};
152146

153147
return map;

onnxruntime/test/optimizer/transpose_optimizer_test.cc

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "test/optimizer/graph_transform_test_builder.h"
2323
#include "test/providers/internal_testing/internal_testing_execution_provider.h"
2424
#include "test/util/include/asserts.h"
25+
#include "test/util/include/default_providers.h"
2526
#include "test/util/include/inference_session_wrapper.h"
2627
#include "test/util/include/test_utils.h"
2728

@@ -3800,6 +3801,46 @@ TEST(TransposeOptimizerTests, TestCast) {
38003801
/*opset_version*/ {15, 18});
38013802
}
38023803

3804+
// Checks that a Transpose -> QLinearSoftmax -> Transpose chain is optimized so that the
// estimated transpose cost of the resulting graph is 0.
TEST(TransposeOptimizerTests, TestQLinearSoftmax) {
3805+
// Builds the graph under test: uint8 input -> Transpose -> QLinearSoftmax (kMSDomain) -> Transpose -> output.
auto build_test_case_1 = [&](ModelTestBuilder& builder) {
3806+
// NOTE(review): {1, 384, 384, 21} is presumably the input shape and 0/255 the value range - confirm against MakeInput.
auto* input0_arg = MakeInput<uint8_t>(builder, std::nullopt, {1, 384, 384, 21}, 0, 255);
3807+
auto* transpose_1_out_0 = builder.MakeIntermediate();
3808+
// Scalar quantization parameters passed to QLinearSoftmax after the data input
// (x scale, x zero point, y scale, y zero point, in that order).
auto* input_x_scale = builder.MakeScalarInitializer<float>(0.5086354613304138);
3809+
auto* input_x_zero_point = builder.MakeScalarInitializer<uint8_t>(74);
3810+
auto* input_y_scale = builder.MakeScalarInitializer<float>(0.003921568859368563);
3811+
auto* input_y_zero_point = builder.MakeScalarInitializer<uint8_t>(0);
3812+
auto* qlinearsoftmax_1_out_0 = builder.MakeIntermediate();
3813+
auto* transpose_2_out_0 = builder.MakeOutput();
3814+
3815+
// First Transpose: perm {0, 3, 1, 2} moves the last input dimension to position 1.
auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0});
3816+
transpose_1.AddAttribute("perm", std::vector<int64_t>{0, 3, 1, 2});
3817+
auto& qlinearsoftmax_1 = builder.AddNode("QLinearSoftmax",
3818+
{transpose_1_out_0, input_x_scale, input_x_zero_point, input_y_scale, input_y_zero_point},
3819+
{qlinearsoftmax_1_out_0}, kMSDomain);
3820+
// Softmax is taken over axis 1 of the transposed tensor, i.e. the dimension the first Transpose moved there.
qlinearsoftmax_1.AddAttribute("axis", static_cast<int64_t>(1));
3821+
qlinearsoftmax_1.AddAttribute("opset", static_cast<int64_t>(13));
3822+
// Second Transpose: perm {0, 2, 3, 1} is the inverse of the first permutation.
auto& transpose_2 = builder.AddNode("Transpose", {qlinearsoftmax_1_out_0}, {transpose_2_out_0});
3823+
transpose_2.AddAttribute("perm", std::vector<int64_t>{0, 2, 3, 1});
3824+
};
3825+
3826+
// After optimization, EstimateTransposeCost must report 0 for the session graph.
auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) {
3827+
int transpose_cost = EstimateTransposeCost(session.GetGraph());
3828+
EXPECT_EQ(transpose_cost, 0);
3829+
};
3830+
3831+
// Runs with Level2 and Level3 transformer levels at opset 13, zero tolerances, on the default CPU execution provider.
TransformerTester(build_test_case_1,
3832+
check_optimized_graph_1,
3833+
TransformerLevel::Level2,
3834+
TransformerLevel::Level3,
3835+
/*opset_version*/ 13,
3836+
/*per_sample_tolerance*/ 0.0,
3837+
/*relative_per_sample_tolerance*/ 0.0,
3838+
/*transformer*/ nullptr,
3839+
/*add_session_options*/ {},
3840+
/*disabled_optimizers*/ {},
3841+
/*ep*/ DefaultCpuExecutionProvider());
3842+
}
3843+
38033844
TEST(TransposeOptimizerTests, TestBroadcastReusedInputs) {
38043845
auto build_test_case_1 = [&](ModelTestBuilder& builder) {
38053846
auto* input0_arg = MakeInput<float>(builder, {{-1, -1, 3, 4}}, {1, 2, 3, 4}, 0.0, 1.0);

0 commit comments

Comments
 (0)