Skip to content

Commit 97d8d90

Browse files
authored
[QNN-EP] Add Support for CumSum in QNN EP (#24820)
- Registered the CumSum op in QNN EP. - Added a unit test to verify accuracy and assignment of the op to QNN EP. ### Description Added support for CumSum in QNN EP. ### Motivation and Context There is currently no support for the CumSum op in QNN EP.
1 parent 9415b94 commit 97d8d90

File tree

8 files changed

+327
-0
lines changed

8 files changed

+327
-0
lines changed

onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,24 @@ bool TopKNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& n
764764
return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath());
765765
}
766766

767+
bool CumSumNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node,
768+
const std::vector<const Node*>& dq_nodes,
769+
const std::vector<const Node*>& q_nodes) const {
770+
// Only the first input has DQ node
771+
if (!CheckQDQNodes(graph_viewer, node, redundant_clip_node, dq_nodes, q_nodes, 1)) {
772+
return false;
773+
}
774+
775+
int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
776+
int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
777+
778+
if (dt_input != dt_output) {
779+
return false;
780+
}
781+
782+
return true;
783+
}
784+
767785
} // namespace QDQ
768786
} // namespace onnxruntime
769787

onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,14 @@ class TopKNodeGroupSelector : public NodeGroupSelector {
285285
const std::vector<const Node*>& q_nodes) const override;
286286
};
287287

288+
// Selects a QDQ node group for CumSum: one DQ node feeding the first (data) input,
// the CumSum node itself, and one Q node consuming its output.
// The axis input (input 1) is an integer tensor and has no DQ node.
class CumSumNodeGroupSelector : public NodeGroupSelector {
  bool Check(const GraphViewer& graph_viewer,
             const Node& node, const Node* redundant_clip_node,
             const std::vector<const Node*>& dq_nodes,
             const std::vector<const Node*>& q_nodes) const override;
};
295+
288296
/*
289297
* NodeSelector instances for use in the QDQ::SelectorActionTransformer.
290298
*/

onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ static const OpVersionsAndSelector::OpVersionsMap GetPadOpVersionsMap() {
146146
static const OpVersionsAndSelector::OpVersionsMap GetTopKOpVersionsMap() {
147147
return {{"TopK", {}}};
148148
}
149+
// CumSum is selectable for every opset; an empty version list means "no restriction".
static const OpVersionsAndSelector::OpVersionsMap GetCumSumOpVersionsMap() {
  return OpVersionsAndSelector::OpVersionsMap{{"CumSum", {}}};
}
149152

150153
/* Selector rules registration related */
151154
void RegisterMiscSelectors(Selectors& qdq_selectors) {
@@ -268,6 +271,13 @@ void RegisterTopKSelector(Selectors& qdq_selectors) {
268271
std::move(selector));
269272
}
270273

274+
// Registers the QDQ node-group selector for the CumSum op.
void RegisterCumSumSelector(Selectors& qdq_selectors) {
  qdq_selectors.RegisterSelector(GetCumSumOpVersionsMap(),
                                 std::make_unique<CumSumNodeGroupSelector>());
}
280+
271281
void SelectorManager::CreateSelectors() {
272282
RegisterMiscSelectors(qdq_selectors_);
273283
RegisterDropDQSelectors(qdq_selectors_);
@@ -286,6 +296,7 @@ void SelectorManager::CreateSelectors() {
286296
RegisterWhereSelectors(qdq_selectors_);
287297
RegisterPadSelectors(qdq_selectors_);
288298
RegisterTopKSelector(qdq_selectors_);
299+
RegisterCumSumSelector(qdq_selectors_);
289300
}
290301

291302
void SelectorManager::InitializeSelectorsMap() {

onnxruntime/core/providers/qnn/builder/op_builder_factory.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
185185
{
186186
CreateLSTMOpBuilder("LSTM", *this);
187187
}
188+
189+
{
190+
CreateCumSumOpBuilder("CumSum", *this);
191+
}
188192
}
189193

190194
const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) {

onnxruntime/core/providers/qnn/builder/op_builder_factory.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,5 +104,8 @@ void CreateMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& o
104104
void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
105105

106106
void CreateLSTMOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
107+
108+
// Registers the CumSum op builder under the given ONNX op type name.
void CreateCumSumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
109+
107110
} // namespace qnn
108111
} // namespace onnxruntime

onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ class BaseOpBuilder : public IOpBuilder {
230230

231231
{"LogSoftmax", QNN_OP_LOG_SOFTMAX},
232232
{"Concat", QNN_OP_CONCAT},
233+
{"CumSum", QNN_OP_CUMULATIVE_SUM},
233234

234235
{"Gemm", QNN_OP_FULLY_CONNECTED},
235236

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
5+
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
6+
#include "core/providers/qnn/builder/qnn_utils.h"
7+
#include "core/providers/qnn/builder/op_builder_factory.h"
8+
9+
namespace onnxruntime {
10+
namespace qnn {
11+
namespace {
12+
13+
// Reads the CumSum "axis" input (input 1), which must be a scalar initializer,
// normalizes a negative value against the input rank, and returns it as a
// non-negative axis in [0, rank-1] as required by QNN.
// Returns an error status if axis is not a constant scalar or is out of range.
Status GetOnnxAxis(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, uint32_t& onnx_axis) {
  const auto& inputs = node_unit.Inputs();
  TensorInfo axis_input_info = {};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[1], axis_input_info));
  ORT_RETURN_IF_NOT(axis_input_info.is_initializer, "axis must be initializers");
  std::vector<uint8_t> axis_unpacked_tensor;
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*axis_input_info.initializer_tensor, axis_unpacked_tensor));

  // ONNX allows the axis tensor to be int32 or int64, so compute the element byte
  // size from the actual data type. (The previous sizeof(axis_input_info.qnn_data_type)
  // measured the Qnn_DataType_t enum itself, which wrongly rejected int64 scalars.)
  const size_t elem_byte_size =
      (axis_input_info.qnn_data_type == QNN_DATATYPE_INT_64) ? sizeof(int64_t) : sizeof(int32_t);
  ORT_RETURN_IF_NOT(axis_unpacked_tensor.size() == elem_byte_size,
                    "axis should be a single element");

  int32_t axis = 0;
  if (axis_input_info.qnn_data_type == QNN_DATATYPE_INT_64) {
    axis = static_cast<int32_t>(*reinterpret_cast<const int64_t*>(axis_unpacked_tensor.data()));
  } else {
    axis = *reinterpret_cast<const int32_t*>(axis_unpacked_tensor.data());
  }

  std::vector<uint32_t> input_shape;
  ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape");

  // Normalize a negative ONNX axis into QNN's required [0, rank-1] range.
  const auto rank = static_cast<int32_t>(input_shape.size());
  if (axis < 0) {
    axis += rank;
  }

  ORT_RETURN_IF_NOT((axis >= 0 && axis < rank), "QNN requires axis range [0, rank-1].");

  onnx_axis = static_cast<uint32_t>(axis);

  return Status::OK();
}
44+
45+
} // namespace
46+
47+
// Builds QNN CumulativeSum ops from ONNX CumSum nodes.
// The ONNX axis input (input 1) becomes a QNN scalar parameter, so it must be a
// constant initializer; the exclusive/reverse attributes map to QNN bool8 params.
class CumSumOpBuilder : public BaseOpBuilder {
 public:
  CumSumOpBuilder() : BaseOpBuilder("CumSumOpBuilder") {}
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CumSumOpBuilder);

  // Validates QNN support: axis must be constant; exclusive/reverse must be 0.
  Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
                       const NodeUnit& node_unit,
                       const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

  // Processes only the data input (input 0); axis is consumed as a parameter instead.
  Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
                       const NodeUnit& node_unit,
                       const logging::Logger& logger,
                       std::vector<std::string>& input_names,
                       bool do_op_validation) const override ORT_MUST_USE_RESULT;

  // Translates axis/exclusive/reverse into QNN params and wires up the output.
  Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
                                     const NodeUnit& node_unit,
                                     std::vector<std::string>&& input_names,
                                     const logging::Logger& logger,
                                     bool do_op_validation) const override ORT_MUST_USE_RESULT;
};
68+
69+
Status CumSumOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
                                      const NodeUnit& node_unit,
                                      const logging::Logger& logger) const {
  // The axis becomes a QNN scalar parameter, so it must be known at build time.
  const auto& inputs = node_unit.Inputs();
  ORT_RETURN_IF_NOT(qnn_model_wrapper.IsConstantInput(inputs[1].node_arg.Name()),
                    "QNN CumSum needs axis as a param, hence input[1] must be a constant.");

  // QNN HTP op validation passes for non-default values of attributes but fails in finalize.
  // Hence adding the checks here.
  NodeAttrHelper node_helper(node_unit);
  const int64_t exclusive = node_helper.Get("exclusive", static_cast<int64_t>(0));
  const int64_t reverse = node_helper.Get("reverse", static_cast<int64_t>(0));
  ORT_RETURN_IF_NOT(exclusive == 0, "QNN only supports default value 0 for exclusive attribute");
  ORT_RETURN_IF_NOT(reverse == 0, "QNN only supports default value 0 for reverse attribute");

  // Run the node through the QNN validation path.
  return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true);
}
87+
88+
Status CumSumOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
                                      const NodeUnit& node_unit,
                                      const logging::Logger& logger,
                                      std::vector<std::string>& input_names,
                                      bool do_op_validation) const {
  ORT_UNUSED_PARAMETER(do_op_validation);
  // Only the data input (input 0) becomes a QNN tensor input; the axis input
  // is turned into a parameter in ProcessAttributesAndOutputs.
  return ProcessInput(qnn_model_wrapper, node_unit.Inputs()[0], logger, input_names);
}
98+
99+
Status CumSumOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
                                                    const NodeUnit& node_unit,
                                                    std::vector<std::string>&& input_names,
                                                    const logging::Logger& logger,
                                                    bool do_op_validation) const {
  ORT_UNUSED_PARAMETER(do_op_validation);

  std::vector<std::string> param_tensor_names;

  // axis: resolved from the constant axis input and passed as a uint32 scalar param.
  uint32_t onnx_axis = 0;
  ORT_RETURN_IF_ERROR(GetOnnxAxis(qnn_model_wrapper, node_unit, onnx_axis));
  Qnn_Scalar_t axis_scalar = QNN_SCALAR_INIT;
  axis_scalar.dataType = QNN_DATATYPE_UINT_32;
  axis_scalar.uint32Value = onnx_axis;
  QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_CUMULATIVE_SUM_PARAM_AXIS, axis_scalar);
  param_tensor_names.push_back(axis_param.GetParamTensorName());
  qnn_model_wrapper.AddParamWrapper(std::move(axis_param));

  // exclusive / reverse: ONNX int attributes mapped to QNN bool8 scalar params.
  NodeAttrHelper node_helper(node_unit);
  auto add_bool_param = [&](const char* qnn_param_name, int64_t attr_value) {
    Qnn_Scalar_t bool_scalar = QNN_SCALAR_INIT;
    bool_scalar.dataType = QNN_DATATYPE_BOOL_8;
    bool_scalar.bool8Value = static_cast<uint8_t>(attr_value == 0 ? 0 : 1);
    QnnParamWrapper bool_param(node_unit.Index(), node_unit.Name(), qnn_param_name, bool_scalar);
    param_tensor_names.push_back(bool_param.GetParamTensorName());
    qnn_model_wrapper.AddParamWrapper(std::move(bool_param));
  };
  add_bool_param(QNN_OP_CUMULATIVE_SUM_PARAM_EXCLUSIVE, node_helper.Get("exclusive", static_cast<int64_t>(0)));
  add_bool_param(QNN_OP_CUMULATIVE_SUM_PARAM_REVERSE, node_helper.Get("reverse", static_cast<int64_t>(0)));

  return ProcessOutputs(qnn_model_wrapper, node_unit,
                        std::move(input_names),
                        std::move(param_tensor_names),
                        logger, do_op_validation, GetQnnOpType(node_unit.OpType()));
}
142+
143+
void CreateCumSumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
144+
op_registrations.AddOpBuilder(op_type, std::make_unique<CumSumOpBuilder>());
145+
}
146+
147+
} // namespace qnn
148+
} // namespace onnxruntime
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#if !defined(ORT_MINIMAL_BUILD)
5+
6+
#include <string>
7+
#include "core/graph/graph.h"
8+
#include "core/graph/node_attr_utils.h"
9+
10+
#include "test/providers/qnn/qnn_test_utils.h"
11+
12+
#include "gtest/gtest.h"
13+
14+
namespace onnxruntime {
15+
namespace test {
16+
#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
17+
18+
// Runs a non-QDQ model on HTP and compares output to CPU EP.
19+
template <typename InputType1 = float, typename InputType2 = float>
20+
static void RunCumSumOpTest(const std::string& op_type,
21+
const TestInputDef<InputType1>& input_def_1,
22+
const TestInputDef<InputType2>& input_def_2,
23+
const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs,
24+
int opset_version,
25+
ExpectedEPNodeAssignment expected_ep_assignment,
26+
float fp32_abs_err = 2e-3f) {
27+
ProviderOptions provider_options;
28+
provider_options["backend_type"] = "htp";
29+
provider_options["offload_graph_io_quantization"] = "0";
30+
31+
// Runs model with a Q/DQ binary op and compares the outputs of the CPU and QNN EPs.
32+
RunQnnModelTest(BuildOpTestCase<InputType1, InputType2>(op_type, {input_def_1}, {input_def_2}, attrs),
33+
provider_options,
34+
opset_version,
35+
expected_ep_assignment,
36+
fp32_abs_err);
37+
}
38+
39+
// Non-QDQ model: float CumSum with the axis input supplied as an initializer
// (axis = 0, exclusive = 0, reverse = 0). Expects full assignment to QNN EP.
TEST_F(QnnHTPBackendTests, CumSum_float_int32_e0_r0_axis_0) {
  RunCumSumOpTest<float, int32_t>("CumSum",
                                  TestInputDef<float>({3, 2}, false, {1.3f, 7.2f, 0.4f, 3.4f, 5.7f, 0.8f}),
                                  TestInputDef<int32_t>({}, true, {0}),  // scalar axis initializer
                                  {utils::MakeAttribute("exclusive", static_cast<int64_t>(0)),
                                   utils::MakeAttribute("reverse", static_cast<int64_t>(0))},
                                  17,  // opset version
                                  ExpectedEPNodeAssignment::All);
}
49+
50+
// Non-QDQ model: float CumSum with a negative axis (-1) supplied as an
// initializer, exercising the negative-axis normalization path.
TEST_F(QnnHTPBackendTests, CumSum_float_int32_e0_r0_axis_neg1) {
  RunCumSumOpTest<float, int32_t>("CumSum",
                                  TestInputDef<float>({3, 2}, false, {1.3f, 7.2f, 0.4f, 3.4f, 5.7f, 0.8f}),
                                  TestInputDef<int32_t>({}, true, {-1}),  // negative axis, normalized to rank-1
                                  {utils::MakeAttribute("exclusive", static_cast<int64_t>(0)),
                                   utils::MakeAttribute("reverse", static_cast<int64_t>(0))},
                                  17,  // opset version
                                  ExpectedEPNodeAssignment::All);
}
60+
61+
// Returns a model-builder function that creates a QDQ CumSum graph:
//   input -> Q -> DQ -> CumSum -> Q -> DQ -> output
// The axis input is added unquantized (it is an integer initializer, not activation data).
template <typename QuantType, typename AxisType>
GetTestQDQModelFn<QuantType> BuildQDQCumSumTestCase(const TestInputDef<float>& input_def,
                                                    const TestInputDef<AxisType>& axis_def,
                                                    const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs,
                                                    bool use_contrib_qdq = false) {
  return [input_def, axis_def, attrs, use_contrib_qdq](ModelTestBuilder& builder,
                                                       std::vector<QuantParams<QuantType>>& output_qparams) {
    // Quantize then dequantize the data input: input -> Q -> DQ ->
    NodeArg* input = MakeTestInput(builder, input_def);
    QuantParams<QuantType> input_qparams = GetTestInputQuantParams<QuantType>(input_def);
    NodeArg* input_qdq = AddQDQNodePair<QuantType>(builder, input, input_qparams.scale, input_qparams.zero_point,
                                                   use_contrib_qdq);

    // The axis input feeds CumSum directly, with no Q/DQ pair.
    NodeArg* axis_input = MakeTestInput(builder, axis_def);

    // Add the CumSum node with the requested attributes.
    NodeArg* op_output = builder.MakeIntermediate();
    Node& cumsum_node = builder.AddNode("CumSum", {input_qdq, axis_input}, {op_output});

    for (const auto& attr : attrs) {
      cumsum_node.AddAttributeProto(attr);
    }

    // Quantize the op result and expose the dequantized tensor as the graph output:
    // op_output -> Q -> DQ -> output
    AddQDQNodePairWithOutputAsGraphOutput<QuantType>(builder, op_output, output_qparams[0].scale,
                                                     output_qparams[0].zero_point, use_contrib_qdq);
  };
}
91+
92+
// Test the accuracy of a QDQ CumSum model on QNN EP. Checks if the QDQ model on QNN EP is as accurate as the QDQ model on CPU EP
93+
// (compared to float32 model).
94+
template <typename QuantType, typename AxisType>
95+
static void RunQDQCumSumOpTest(const TestInputDef<float>& input_def,
96+
const TestInputDef<AxisType>& axis_def,
97+
const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs,
98+
int opset,
99+
ExpectedEPNodeAssignment expected_ep_assignment,
100+
bool use_contrib_qdq = false) {
101+
ProviderOptions provider_options;
102+
provider_options["backend_type"] = "htp";
103+
provider_options["offload_graph_io_quantization"] = "0";
104+
105+
auto f32_model_builder = BuildOpTestCase<float, AxisType>("CumSum", {input_def}, {axis_def}, attrs);
106+
auto qdq_model_builder = BuildQDQCumSumTestCase<QuantType, AxisType>(input_def, axis_def, attrs,
107+
use_contrib_qdq);
108+
109+
TestQDQModelAccuracy<QuantType>(f32_model_builder,
110+
qdq_model_builder,
111+
provider_options,
112+
opset,
113+
expected_ep_assignment);
114+
}
115+
116+
// Test creates a DQ -> CumSum -> Q -> DQ graph, and checks that all
// nodes are supported by the QNN EP, and that the inference results are as accurate as CPU EP.
//
// QDQ model: uint8-quantized CumSum with the axis input as an int32 initializer.
TEST_F(QnnHTPBackendTests, CumSum_uint8_int32_e0_r0) {
  RunQDQCumSumOpTest<uint8_t, int32_t>(TestInputDef<float>({3, 2}, false, {1.3f, 7.2f, 0.4f, 3.4f, 5.7f, 0.8f}),
                                       TestInputDef<int32_t>({}, true, {0}),  // scalar axis initializer
                                       {utils::MakeAttribute("exclusive", static_cast<int64_t>(0)),
                                        utils::MakeAttribute("reverse", static_cast<int64_t>(0))},
                                       17,  // opset version
                                       ExpectedEPNodeAssignment::All);
}
128+
129+
#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
130+
131+
} // namespace test
132+
} // namespace onnxruntime
133+
134+
#endif

0 commit comments

Comments
 (0)