4 changes: 4 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.cc
@@ -173,6 +173,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
CreateExpandOpBuilder("Expand", *this);
}

{
CreateEinsumOpBuilder("Einsum", *this);
}

{
CreateMatMulOpBuilder("MatMul", *this);
}
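
For context, each registration in this factory is a one-line helper that maps an ONNX op type onto an IOpBuilder instance. A minimal sketch of what the new helper presumably looks like, assuming it mirrors the neighboring CreateMatMulOpBuilder pattern (the actual definition is in einsum_op_builder.cc, whose diff is not rendered on this page):

```cpp
// Hypothetical sketch, mirroring the factory's existing registration helpers.
// EinsumOpBuilder is assumed to be the IOpBuilder subclass added in
// einsum_op_builder.cc (not rendered in this diff view).
void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  op_registrations.AddOpBuilder(op_type, std::make_unique<EinsumOpBuilder>());
}
```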
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.h
@@ -100,5 +100,7 @@ void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& o
void CreateHardSigmoidOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

void CreateMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
+
+void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
} // namespace qnn
} // namespace onnxruntime
onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc
@@ -94,13 +94,13 @@ Status ClipOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const N
  if (node_unit.Inputs().size() > 1) {
    const auto& min_input_name = node_unit.Inputs()[1].node_arg.Name();
    if (!min_input_name.empty() && !qnn_model_wrapper.IsConstantInput(min_input_name)) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN desn't support dynamic min/max.");
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN doesn't support dynamic min/max.");
    }
  }
  if (node_unit.Inputs().size() > 2) {
    const auto& max_input_name = node_unit.Inputs()[2].node_arg.Name();
    if (!max_input_name.empty() && !qnn_model_wrapper.IsConstantInput(max_input_name)) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN desn't support dynamic min/max.");
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN doesn't support dynamic min/max.");
    }
  }
  return Status::OK();
396 changes: 396 additions & 0 deletions onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc

Large diffs are not rendered by default.
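Since the 396-line builder itself is not rendered, here is a minimal sketch of the equation classification such a builder has to perform. The test names below (MatMul, TransposeY, TransposeAll) suggest the builder lowers two-operand einsum equations onto QNN MatMul, transposing an input when its trailing labels are swapped. All names and structure here are hypothetical; the permuted "TransposeAll" equations in the tests need additional transposes that this sketch simply declines:

```cpp
#include <optional>
#include <string>

struct MatMulPlan {
  bool transpose_a;  // feed input 0 transposed
  bool transpose_b;  // feed input 1 transposed
};

// Classify a two-operand einsum equation whose leading labels are shared
// batch dimensions, e.g. "bhij,bhjd->bhid", by inspecting only the trailing
// two labels of each term. Hypothetical helper; batch-label agreement and
// label validity are not verified here.
std::optional<MatMulPlan> ClassifyEinsumAsMatMul(const std::string& equation) {
  const size_t comma = equation.find(',');
  const size_t arrow = equation.find("->");
  if (comma == std::string::npos || arrow == std::string::npos || arrow < comma) return std::nullopt;

  const std::string a = equation.substr(0, comma);
  const std::string b = equation.substr(comma + 1, arrow - comma - 1);
  const std::string out = equation.substr(arrow + 2);
  if (a.size() < 2 || b.size() < 2 || out.size() < 2) return std::nullopt;

  const char a0 = a[a.size() - 2], a1 = a[a.size() - 1];
  const char b0 = b[b.size() - 2], b1 = b[b.size() - 1];
  const char o0 = out[out.size() - 2], o1 = out[out.size() - 1];

  if (a0 == o0 && b1 == o1 && a1 == b0) return MatMulPlan{false, false};  // "ik,kj->ij"
  if (a0 == o0 && b0 == o1 && a1 == b1) return MatMulPlan{false, true};   // "ik,jk->ij"
  if (a1 == o0 && b1 == o1 && a0 == b0) return MatMulPlan{true, false};   // "ki,kj->ij"
  if (a1 == o0 && b0 == o1 && a0 == b1) return MatMulPlan{true, true};    // "ki,jk->ij"
  return std::nullopt;
}
```

Against the tests below, "bhij,bhjd->bhid" classifies as a plain MatMul, "bhid,bhjd->bhij" as a MatMul with the second input transposed, and "bchq,bkhc->bkhq" falls through to nullopt because its batch labels are permuted.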

onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc
@@ -46,7 +46,7 @@ Status SliceOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const
    for (size_t i = 1; i < input_count; i++) {
      const auto& next_input = node_unit.Inputs()[i].node_arg.Name();
      if (!qnn_model_wrapper.IsConstantInput(next_input)) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN desn't support dynamic slice.");
+        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN doesn't support dynamic slice.");
      }
    }
  }
@@ -44,7 +44,7 @@ std::vector<uint32_t> FlattenShapeFromAxis(std::vector<uint32_t>& input_shape, i
  Return the shape with all dimensions multiplied onward from the specified axis. If axis is 0, the returned shape
  will include an additional batch of size 1 as the first dimension.
  */
-  assert(axis >= 0 && axis < input_shape.size());
+  assert(axis >= 0 && static_cast<size_t>(axis) < input_shape.size());
  std::vector<uint32_t> output_shape(input_shape.begin(), input_shape.begin() + axis);

  if (axis == 0) {
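A quick illustration of the documented behavior (hypothetical checks, assuming FlattenShapeFromAxis does exactly what its comment above states):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical usage, assuming FlattenShapeFromAxis as declared above.
void CheckFlattenShapeFromAxis() {
  std::vector<uint32_t> shape{2, 3, 4};
  // Dimensions from the axis onward are multiplied together: {2, 3*4}.
  assert((FlattenShapeFromAxis(shape, /*axis=*/1) == std::vector<uint32_t>{2, 12}));
  // axis == 0 prepends a batch dimension of size 1: {1, 2*3*4}.
  assert((FlattenShapeFromAxis(shape, /*axis=*/0) == std::vector<uint32_t>{1, 24}));
}
```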
onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc
@@ -42,7 +42,7 @@ Status TileOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
                                    std::vector<std::string>& input_names,
                                    bool do_op_validation) const {
  const auto& inputs = node_unit.Inputs();
-  // QNN Tile only support 1 input, the 2nd input need to be initialier and set as Qnn node parameter
+  // QNN Tile only support 1 input, the 2nd input need to be initializer and set as Qnn node parameter
  if (do_op_validation) {
    auto& repeats_input_name = inputs[1].node_arg.Name();
    ORT_RETURN_IF_NOT(qnn_model_wrapper.IsConstantInput(repeats_input_name),
@@ -60,7 +60,7 @@ Status TileOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
                                                  const logging::Logger& logger,
                                                  bool do_op_validation) const {
  std::vector<std::string> param_tensor_names;
-  // Already confirmed repeats input is initailizer in ProcessInputs()
+  // Already confirmed repeats input is initializer in ProcessInputs()
  const auto& repeats_input_name = node_unit.Inputs()[1].node_arg.Name();

  std::vector<uint8_t> unpacked_tensor;
10 changes: 6 additions & 4 deletions onnxruntime/core/providers/qnn/builder/qnn_model.cc
@@ -180,14 +180,16 @@ Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) {
  auto result = SetupTensors(qnn_input_infos_, graph_info_->InputTensors());

  if (Status::OK() != result) {
-    LOGS(logger, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
-    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN input tensors!");
+    const std::string message = "Failed to setup QNN input tensors for graph: " + graph_info_->Name();
+    LOGS(logger, ERROR) << message;
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message);
  }

  result = SetupTensors(qnn_output_infos_, graph_info_->OutputTensors(), false);
  if (Status::OK() != result) {
-    LOGS(logger, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
-    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN output tensors!");
+    const std::string message = "Failed to setup QNN output tensors for graph: " + graph_info_->Name();
+    LOGS(logger, ERROR) << message;
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message);
  }

  return Status::OK();
190 changes: 190 additions & 0 deletions onnxruntime/test/providers/qnn/einsum_op_test.cc
@@ -0,0 +1,190 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#if !defined(ORT_MINIMAL_BUILD)

#include <string>
#include <vector>

#include "test/providers/qnn/qnn_test_utils.h"
#include "core/graph/node_attr_utils.h"
#include "test/util/include/test_utils.h"

#include "core/graph/onnx_protobuf.h"
#include "gtest/gtest.h"

namespace {

using onnxruntime::ProviderOptions;
using onnxruntime::test::BuildOpTestCase;
using onnxruntime::test::ExpectedEPNodeAssignment;
using onnxruntime::test::RunQnnModelTest;
using onnxruntime::test::TestInputDef;
using onnxruntime::utils::MakeAttribute;

template <typename DataType>
static void RunQnnEinsum(
    const std::string& backend,
    const TestInputDef<DataType>& in0,
    const TestInputDef<DataType>& in1,
    const std::string& equation,
    const float f32_abs_err = 1e-4f) {
  ProviderOptions provider_options;
  provider_options["backend_type"] = backend;
  provider_options["offload_graph_io_quantization"] = "0";
  RunQnnModelTest(
      /*build_test_case=*/BuildOpTestCase<DataType, DataType>(
          /*op_type=*/"Einsum",
          /*input_defs_1=*/{in0, in1},
          /*input_defs_2=*/{},
          /*attrs=*/{MakeAttribute("equation", equation)}),
      /*provider_options=*/provider_options,
      /*opset_version=*/13,
      /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All,
      /*f32_abs_err=*/f32_abs_err);
}

} // namespace

namespace onnxruntime {
namespace test {

//
// QNN CPU
//

TEST_F(QnnCPUBackendTests, EinsumRank2) {
  const std::vector<int64_t> shape0{2, 3};
  const std::vector<int64_t> shape1{3, 4};
  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
  RunQnnEinsum<float>(
      /*backend=*/"cpu",
      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
      /*equation=*/"ab,bc->ac");
}

TEST_F(QnnCPUBackendTests, EinsumRank4MatMul) {
  const std::vector<int64_t> shape0{3, 4, 5, 6};
  const std::vector<int64_t> shape1{3, 4, 6, 5};
  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
  RunQnnEinsum<float>(
      /*backend=*/"cpu",
      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
      /*equation=*/"bhij,bhjd->bhid");
}

TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeY) {
  const std::vector<int64_t> shape0{2, 3, 4, 6};
  const std::vector<int64_t> shape1{2, 3, 5, 6};
  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
  RunQnnEinsum<float>(
      /*backend=*/"cpu",
      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
      /*equation=*/"bhid,bhjd->bhij");
}

TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll1) {
  const std::vector<int64_t> shape0{1, 9, 1, 7};
  const std::vector<int64_t> shape1{1, 7, 1, 9};
  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
  RunQnnEinsum<float>(
      /*backend=*/"cpu",
      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
      /*equation=*/"bchq,bkhc->bkhq");
}

TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) {
  const std::vector<int64_t> shape0{1, 7, 1, 7};
  const std::vector<int64_t> shape1{1, 9, 1, 7};
  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
  RunQnnEinsum<float>(
      /*backend=*/"cpu",
      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
      /*equation=*/"bkhq,bchk->bchq");
}

//
// QNN HTP
//

#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)

TEST_F(QnnHTPBackendTests, EinsumRank2MatMul) {
  const std::vector<int64_t> shape0{2, 3};
  const std::vector<int64_t> shape1{3, 4};
  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
  RunQnnEinsum<float>(
      /*backend=*/"htp",
      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
      /*equation=*/"ij,jk->ik",
      /*f32_abs_err=*/1e-2f);
}

TEST_F(QnnHTPBackendTests, EinsumRank4MatMul) {
  const std::vector<int64_t> shape0{3, 1, 5, 2};
  const std::vector<int64_t> shape1{3, 1, 2, 5};
  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
  RunQnnEinsum<float>(
      /*backend=*/"htp",
      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
      /*equation=*/"bhij,bhjd->bhid",
      /*f32_abs_err=*/1e-2f);
}

TEST_F(QnnHTPBackendTests, EinsumRank4MatMulTransposeY) {
  const std::vector<int64_t> shape0{2, 3, 4, 2};
  const std::vector<int64_t> shape1{2, 3, 5, 2};
  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
  RunQnnEinsum<float>(
      /*backend=*/"htp",
      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
      /*equation=*/"bhid,bhjd->bhij",
      /*f32_abs_err=*/1e-2f);
}

TEST_F(QnnHTPBackendTests, EinsumRank4MatMulTransposeAll1) {
  const std::vector<int64_t> shape0{1, 3, 1, 7};
  const std::vector<int64_t> shape1{1, 7, 1, 3};
  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
  RunQnnEinsum<float>(
      /*backend=*/"htp",
      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
      /*equation=*/"bchq,bkhc->bkhq",
      /*f32_abs_err=*/1e-2f);
}

TEST_F(QnnHTPBackendTests, EinsumRank4MatMulTransposeAll2) {
  const std::vector<int64_t> shape0{1, 4, 1, 4};
  const std::vector<int64_t> shape1{1, 9, 1, 4};
  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
  RunQnnEinsum<float>(
      /*backend=*/"htp",
      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
      /*equation=*/"bkhq,bchk->bchq",
      /*f32_abs_err=*/1e-2f);
}
#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)

} // namespace test
} // namespace onnxruntime
#endif // !defined(ORT_MINIMAL_BUILD)