From b23cb9536a7bdbdaa3fe0fc554b0e8af08de2019 Mon Sep 17 00:00:00 2001
From: Yuduo
Date: Thu, 1 May 2025 13:39:16 -0700
Subject: [PATCH 1/2] [QNN EP] Add Einsum support for some equations

---
 .../qnn/builder/op_builder_factory.cc         |   4 +
 .../qnn/builder/op_builder_factory.h          |   2 +
 .../qnn/builder/opbuilder/clip_op_builder.cc  |   4 +-
 .../builder/opbuilder/einsum_op_builder.cc    | 396 ++++++++++++++++++
 .../qnn/builder/opbuilder/slice_op_builder.cc |   2 +-
 .../builder/opbuilder/softmax_op_builder.cc   |   2 +-
 .../qnn/builder/opbuilder/tile_op_builder.cc  |   4 +-
 .../core/providers/qnn/builder/qnn_model.cc   |  10 +-
 .../test/providers/qnn/einsum_op_test.cc      | 190 +++++++++
 9 files changed, 604 insertions(+), 10 deletions(-)
 create mode 100644 onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc
 create mode 100644 onnxruntime/test/providers/qnn/einsum_op_test.cc

diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc
index 731cb30b74429..a3b823d8b3580 100644
--- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc
+++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc
@@ -173,6 +173,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
     CreateExpandOpBuilder("Expand", *this);
   }
 
+  {
+    CreateEinsumOpBuilder("Einsum", *this);
+  }
+
   {
     CreateMatMulOpBuilder("MatMul", *this);
   }
diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h
index 8366e4e57e9d4..aa1039f857f8e 100644
--- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h
+++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h
@@ -100,5 +100,7 @@ void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& o
 void CreateHardSigmoidOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
 
 void CreateMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
+
+void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
 }  // namespace qnn
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc
index 193b507083360..a1a658d5d963c 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc
@@ -94,13 +94,13 @@ Status ClipOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const N
   if (node_unit.Inputs().size() > 1) {
     const auto& min_input_name = node_unit.Inputs()[1].node_arg.Name();
     if (!min_input_name.empty() && !qnn_model_wrapper.IsConstantInput(min_input_name)) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN desn't support dynamic min/max.");
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN doesn't support dynamic min/max.");
     }
   }
   if (node_unit.Inputs().size() > 2) {
     const auto& max_input_name = node_unit.Inputs()[2].node_arg.Name();
     if (!max_input_name.empty() && !qnn_model_wrapper.IsConstantInput(max_input_name)) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN desn't support dynamic min/max.");
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN doesn't support dynamic min/max.");
     }
   }
   return Status::OK();
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc
new file mode 100644
index 0000000000000..9db0b5202dcd4
--- /dev/null
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc
@@ -0,0 +1,396 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
+#include "core/providers/qnn/builder/qnn_utils.h"
+#include "core/providers/qnn/builder/qnn_model_wrapper.h"
+#include "core/providers/qnn/builder/op_builder_factory.h"
+#include "core/providers/cpu/tensor/slice_helper.h"
+
+namespace {
+
+// Represented as a tuple of 3 strings: (term_1, term_2, result).
+// The equation string is expected to follow the format "term_1,term_2->result".
+using Equation = std::tuple<std::string, std::string, std::string>;
+
+/**
+ * @brief Parses an equation string into its components if it adheres to the expected format.
+ *
+ * @param equation_string The input equation string to parse.
+ * @return A std::optional containing a tuple of 3 strings (term_1, term_2, result) if the parsing is successful.
+ *         Returns std::nullopt if the input string is invalid or does not conform to the expected format.
+ */
+std::optional<Equation> ParseEquation(std::string_view equation_string) {
+  std::string equation(equation_string);
+  equation.erase(std::remove(equation.begin(), equation.end(), ' '),
+                 equation.end());
+  if (equation.empty()) {
+    return std::nullopt;
+  }
+  auto index_arrow = equation.find("->");
+  if (index_arrow == std::string::npos) {
+    return std::nullopt;
+  }
+  const std::string lhs = equation.substr(0, index_arrow);
+  const std::string result = equation.substr(index_arrow + 2);
+  if (lhs.empty() || result.empty()) {
+    return std::nullopt;
+  }
+  auto index_comma = lhs.find(",");
+  if (index_comma == std::string::npos) {
+    return std::nullopt;
+  }
+  const std::string term_1 = lhs.substr(0, index_comma);
+  const std::string term_2 = lhs.substr(index_comma + 1);
+  if (term_1.empty() || term_2.empty()) {
+    return std::nullopt;
+  }
+  if (term_1.size() < 2) {
+    return std::nullopt;
+  }
+  if (term_1.size() != term_2.size()) {
+    return std::nullopt;
+  }
+  if (term_1.size() != result.size()) {
+    return std::nullopt;
+  }
+  if (!std::all_of(term_1.begin(), term_1.end(), [](unsigned char c) { return std::islower(c); })) {
+    return std::nullopt;
+  }
+  if (!std::all_of(term_2.begin(), term_2.end(), [](unsigned char c) { return std::islower(c); })) {
+    return std::nullopt;
+  }
+  if (!std::all_of(result.begin(), result.end(), [](unsigned char c) { return std::islower(c); })) {
+    return std::nullopt;
+  }
+  return std::make_tuple(term_1, term_2, result);
+}
+
+bool IsEquationMatMul(const Equation& equation) {
+  // MatMul: e.g., "ij,jk->ik"
+  const auto& [term_1, term_2, result] = equation;
+  const size_t num_dims = term_1.size();
+  for (size_t i = 0; i < num_dims; ++i) {
+    if (i >= num_dims - 2) {
+      continue;
+    }
+    if (!(term_1[i] == term_2[i] && term_1[i] == result[i])) {
+      return false;
+    }
+  }
+  char term_1_m = term_1[num_dims - 2];
+  char term_2_k = term_2[num_dims - 2];
+  char result_m = result[num_dims - 2];
+  char term_1_k = term_1[num_dims - 1];
+  char term_2_n = term_2[num_dims - 1];
+  char result_n = result[num_dims - 1];
+  if (term_1_m != result_m) {
+    return false;
+  }
+  if (term_1_k != term_2_k) {
+    return false;
+  }
+  if (term_2_n != result_n) {
+    return false;
+  }
+  return true;
+}
+
+bool IsEquationMatMulTransposeY(const Equation& equation) {
+  // MatMul with 2nd input transposed: e.g., "id,jd->ij"
+  const auto& [term_1, term_2, result] = equation;
+  const size_t num_dims = term_1.size();
+  for (size_t i = 0; i < num_dims; ++i) {
+    if (i >= num_dims - 2) {
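+      // The last two axes are the matrix dimensions; they are checked explicitly below.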
+      continue;
+    }
+    if (!(term_1[i] == term_2[i] && term_1[i] == result[i])) {
+      return false;
+    }
+  }
+  char term_1_m = term_1[num_dims - 2];
+  char term_2_k = term_2[num_dims - 2];
+  char result_m = result[num_dims - 2];
+  char term_1_k = term_1[num_dims - 1];
+  char term_2_n = term_2[num_dims - 1];
+  char result_n = result[num_dims - 1];
+  if (term_1_m != result_m) {
+    return false;
+  }
+  if (term_1_k != term_2_n) {
+    return false;
+  }
+  if (term_2_k != result_n) {
+    return false;
+  }
+  return true;
+}
+
+bool IsEquationMatMulTransposeAll(const Equation& equation) {
+  // MatMul transpose both inputs and output, e.g., "bchq,bkhc->bkhq", "bkhq,bchk->bchq"
+  const auto& [term_1, term_2, result] = equation;
+  const size_t num_dims = term_1.size();
+  if (num_dims != 4) {
+    return false;
+  }
+  if (term_1[0] != term_2[0] || term_1[0] != result[0]) {
+    return false;
+  }
+  char term_1_m = term_1[num_dims - 1];
+  char term_1_k = term_1[num_dims - 3];
+  char term_2_k = term_2[num_dims - 1];
+  char term_2_n = term_2[num_dims - 3];
+  char result_m = result[num_dims - 1];
+  char result_n = result[num_dims - 3];
+  if (term_1_m != result_m) {
+    return false;
+  }
+  if (term_1_k != term_2_k) {
+    return false;
+  }
+  if (term_2_n != result_n) {
+    return false;
+  }
+  return true;
+}
+
+/**
+ * @brief Sets the parameter tensor names for a MatMul op.
+ *
+ * @param qnn_model_wrapper Pointer to the QnnModelWrapper instance that manages the QNN model.
+ * @param node_unit Reference to the NodeUnit representing the ONNX node for which the parameters are being set.
+ * @param transpose_in0 Boolean flag indicating whether the 1st input tensor should be transposed (default: false).
+ * @param transpose_in1 Boolean flag indicating whether the 2nd input tensor should be transposed (default: false).
+ * @return A vector of strings containing the names of the parameter tensors added to the QNN model.
+ */
+std::vector<std::string> SetMatMulParamTensorNames(
+    onnxruntime::qnn::QnnModelWrapper* qnn_model_wrapper,
+    const onnxruntime::NodeUnit& node_unit,
+    bool transpose_in0 = false,
+    bool transpose_in1 = false) {
+  std::vector<std::string> param_tensor_names;
+  Qnn_Scalar_t scalar_params[2] = {QNN_SCALAR_INIT, QNN_SCALAR_INIT};
+  scalar_params[0].dataType = QNN_DATATYPE_BOOL_8;
+  scalar_params[1].dataType = QNN_DATATYPE_BOOL_8;
+  scalar_params[0].bool8Value = static_cast<uint8_t>(transpose_in0);
+  scalar_params[1].bool8Value = static_cast<uint8_t>(transpose_in1);
+  onnxruntime::qnn::QnnParamWrapper transpose_in0_param(
+      node_unit.Index(), node_unit.Name(), QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN0, scalar_params[0]);
+  onnxruntime::qnn::QnnParamWrapper transpose_in1_param(
+      node_unit.Index(), node_unit.Name(), QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN1, scalar_params[1]);
+  param_tensor_names.push_back(transpose_in0_param.GetParamTensorName());
+  param_tensor_names.push_back(transpose_in1_param.GetParamTensorName());
+  qnn_model_wrapper->AddParamWrapper(std::move(transpose_in0_param));
+  qnn_model_wrapper->AddParamWrapper(std::move(transpose_in1_param));
+  return param_tensor_names;
+}
+
+/**
+ * @brief Creates a MatMul operation with transposed inputs and output in a QNN model.
+ *
+ * @param qnn_model_wrapper Pointer to the QnnModelWrapper instance used to manage the QNN model.
+ * @param node_unit The NodeUnit representing the ONNX node to be converted.
+ * @param do_op_validation A boolean flag indicating whether to perform operation validation.
+ * @return Status indicating success or failure of the operation.
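+ * Note: lowered as Transpose(perm {0, 2, 1, 3}) on each input, a MatMul on the swapped
+ * transposed operands, and a final Transpose(perm {0, 2, 1, 3}) on the result.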
+ */
+Status CreateMatMulTransposeAll(
+    onnxruntime::qnn::QnnModelWrapper* qnn_model_wrapper,
+    const onnxruntime::NodeUnit& node_unit,
+    std::vector<std::string>&& input_names,
+    bool do_op_validation) {
+  onnxruntime::qnn::TensorInfo input_info0{}, input_info1{};
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Inputs()[0], input_info0));
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Inputs()[1], input_info1));
+  std::vector<uint32_t> input_shape0(input_info0.shape);
+  std::vector<uint32_t> input_shape1(input_info1.shape);
+  std::swap(input_shape0[1], input_shape0[2]);
+  std::swap(input_shape1[1], input_shape1[2]);
+  const std::string input_transpose0 = input_names[0] + "_t0";
+  const std::string input_transpose1 = input_names[1] + "_t1";
+  const std::vector<uint32_t> transpose_perm{0, 2, 1, 3};
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper->AddTransposeNode(
+      /*node_index=*/node_unit.Index(),
+      /*input_name=*/input_names[0],
+      /*output_name=*/input_transpose0,
+      /*input_shape=*/input_info0.shape,
+      /*transpose_perm=*/transpose_perm,
+      /*output_shape=*/input_shape0,
+      /*qnn_data_type=*/input_info0.qnn_data_type,
+      /*quantize_param=*/input_info0.quant_param.Copy(),
+      /*do_op_validation=*/do_op_validation,
+      /*is_for_input=*/qnn_model_wrapper->IsGraphInput(input_names[0])));
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper->AddTransposeNode(
+      /*node_index=*/node_unit.Index(),
+      /*input_name=*/input_names[1],
+      /*output_name=*/input_transpose1,
+      /*input_shape=*/input_info1.shape,
+      /*transpose_perm=*/transpose_perm,
+      /*output_shape=*/input_shape1,
+      /*qnn_data_type=*/input_info1.qnn_data_type,
+      /*quantize_param=*/input_info1.quant_param.Copy(),
+      /*do_op_validation=*/do_op_validation,
+      /*is_for_input=*/qnn_model_wrapper->IsGraphInput(input_names[1])));
+  onnxruntime::qnn::TensorInfo matmul_output_info{};
+  const auto& output = node_unit.Outputs()[0];
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(output, matmul_output_info));
+  const std::string matmul_output_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_matmul";
+  std::vector<uint32_t> matmul_output_shape(matmul_output_info.shape);
+  std::swap(matmul_output_shape[1], matmul_output_shape[2]);
+  onnxruntime::qnn::QnnTensorWrapper matmul_output_wrapper(
+      matmul_output_name, QNN_TENSOR_TYPE_NATIVE, matmul_output_info.qnn_data_type,
+      matmul_output_info.quant_param.Copy(), std::vector<uint32_t>(matmul_output_shape));
+  ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(matmul_output_wrapper)),
+                    node_unit.OpType() + " failed to add tensor.");
+  std::vector<std::string> param_tensor_names = SetMatMulParamTensorNames(
+      qnn_model_wrapper, node_unit, /*transpose_in0=*/false, /*transpose_in1=*/false);
+  ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(/*qnn_node_name=*/onnxruntime::qnn::utils::GetNodeName(node_unit),
+                                                     /*package_name=*/QNN_OP_PACKAGE_NAME_QTI_AISW,
+                                                     /*qnn_node_type=*/QNN_OP_MAT_MUL,
+                                                     /*input_names=*/{input_transpose1, input_transpose0},
+                                                     /*output_names=*/{matmul_output_name},
+                                                     /*param_tensor_names=*/std::move(param_tensor_names),
+                                                     /*do_op_validation=*/do_op_validation),
+                    node_unit.OpType() + " failed to add node.");
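+  // Transpose the MatMul result back into the layout expected by the ONNX Einsum output.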
+  std::vector<uint32_t> transpose_output_shape(matmul_output_info.shape);
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper->AddTransposeNode(
+      /*node_index=*/node_unit.Index(),
+      /*input_name=*/matmul_output_name,
+      /*output_name=*/output.node_arg.Name(),
+      /*input_shape=*/std::move(matmul_output_shape),
+      /*transpose_perm=*/transpose_perm,
+      /*output_shape=*/matmul_output_info.shape,
+      /*tensor_data_type=*/matmul_output_info.qnn_data_type,
+      /*quantize_param=*/matmul_output_info.quant_param.Copy(),
+      /*do_op_validation=*/do_op_validation,
+      /*is_for_input=*/qnn_model_wrapper->IsGraphInput(output.node_arg.Name()),
+      /*is_for_output=*/qnn_model_wrapper->IsGraphOutput(output.node_arg.Name())));
+  return Status::OK();
+}
+
+}  // namespace
+
+namespace onnxruntime {
+namespace qnn {
+
+class EinsumOpBuilder : public BaseOpBuilder {
+ public:
+  EinsumOpBuilder() : BaseOpBuilder("EinsumOpBuilder") {}
+  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(EinsumOpBuilder);
+
+  Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
+                       const NodeUnit& node_unit,
+                       const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
+
+ protected:
+  Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
+                       const NodeUnit& node_unit,
+                       const logging::Logger& logger,
+                       std::vector<std::string>& input_names,
+                       bool do_op_validation) const override ORT_MUST_USE_RESULT;
+
+  Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
+                                     const NodeUnit& node_unit,
+                                     std::vector<std::string>&& input_names,
+                                     const logging::Logger& logger,
+                                     bool do_op_validation) const override ORT_MUST_USE_RESULT;
+
+  Status OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper,
+                                  const NodeUnit& node_unit,
+                                  const logging::Logger& logger,
+                                  const std::vector<std::string>& input_names,
+                                  size_t output_index,
+                                  Qnn_DataType_t qnn_data_type,
+                                  QnnQuantParamsWrapper& quant_param) const override ORT_MUST_USE_RESULT;
+};
+
+Status EinsumOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
+                                      const NodeUnit& node_unit,
+                                      const logging::Logger& logger) const {
+  if (node_unit.Inputs().size() < 2) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " requires at least 2 inputs.");
+  }
+  NodeAttrHelper node_helper{node_unit};
+  const std::string equation = node_helper.Get("equation", std::string(""));
+  std::optional<Equation> parsed_equation = ParseEquation(equation);
+  if (!parsed_equation.has_value()) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation);
+  }
+  if (!IsEquationMatMul(parsed_equation.value()) &&
+      !IsEquationMatMulTransposeY(parsed_equation.value()) &&
+      !IsEquationMatMulTransposeAll(parsed_equation.value())) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation);
+  }
+  return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true);
+}
+
+Status EinsumOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
+                                      const NodeUnit& node_unit,
+                                      const logging::Logger& logger,
+                                      std::vector<std::string>& input_names,
+                                      bool do_op_validation) const {
+  ORT_UNUSED_PARAMETER(do_op_validation);
+  const auto& inputs = node_unit.Inputs();
+  ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names));
+  ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[1], logger, input_names));
+  return Status::OK();
+}
+
+Status EinsumOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
+                                                    const NodeUnit& node_unit,
+                                                    std::vector<std::string>&& input_names,
+                                                    const logging::Logger& logger,
+                                                    bool do_op_validation) const {
+  NodeAttrHelper node_helper(node_unit);
+  const std::string equation = node_helper.Get("equation", std::string(""));
+  std::optional<Equation> parsed_equation = ParseEquation(equation);
+  if (IsEquationMatMul(parsed_equation.value())) {
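+    // Plain MatMul equation: map directly onto QNN MatMul with both transpose flags false.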
+    std::vector<std::string> param_tensor_names = SetMatMulParamTensorNames(
+        &qnn_model_wrapper, node_unit, /*transpose_in0=*/false, /*transpose_in1=*/false);
+    ORT_RETURN_IF_ERROR(ProcessOutputs(/*qnn_model_wrapper=*/qnn_model_wrapper,
+                                       /*node_unit=*/node_unit,
+                                       /*input_names=*/std::move(input_names),
+                                       /*param_tensor_names=*/std::move(param_tensor_names),
+                                       /*logger=*/logger,
+                                       /*do_op_validation=*/do_op_validation,
+                                       /*qnn_op_type=*/QNN_OP_MAT_MUL));
+  } else if (IsEquationMatMulTransposeY(parsed_equation.value())) {
+    std::vector<std::string> param_tensor_names = SetMatMulParamTensorNames(
+        &qnn_model_wrapper, node_unit, /*transpose_in0=*/false, /*transpose_in1=*/true);
+    ORT_RETURN_IF_ERROR(ProcessOutputs(/*qnn_model_wrapper=*/qnn_model_wrapper,
+                                       /*node_unit=*/node_unit,
+                                       /*input_names=*/std::move(input_names),
+                                       /*param_tensor_names=*/std::move(param_tensor_names),
+                                       /*logger=*/logger,
+                                       /*do_op_validation=*/do_op_validation,
+                                       /*qnn_op_type=*/QNN_OP_MAT_MUL));
+  } else if (IsEquationMatMulTransposeAll(parsed_equation.value())) {
+    ORT_RETURN_IF_ERROR(CreateMatMulTransposeAll(&qnn_model_wrapper, node_unit, std::move(input_names), do_op_validation));
+  } else {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation);
+  }
+  return Status::OK();
+}
+
+Status EinsumOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper,
+                                                 const NodeUnit& node_unit,
+                                                 const logging::Logger& logger,
+                                                 const std::vector<std::string>& input_names,
+                                                 size_t output_index,
+                                                 Qnn_DataType_t qnn_data_type,
+                                                 QnnQuantParamsWrapper& quant_param) const {
+  if (!quant_param.IsPerTensor()) {
+    return Status::OK();
+  }
+
+  // Force the operator output to use the same quantization parameters as the input if nearly equal.
+  // This helps the HTP backend employ certain optimizations.
+  return SetOutputQParamEqualToInputIfNearlyEqual(qnn_model_wrapper, node_unit, logger, input_names,
+                                                  0 /*input_index*/, output_index, qnn_data_type, quant_param);
+}
+
+void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
+  op_registrations.AddOpBuilder(op_type, std::make_unique<EinsumOpBuilder>());
+}
+
+}  // namespace qnn
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc
index 19e5ee298f5fb..bcf4df8186dd2 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc
@@ -46,7 +46,7 @@ Status SliceOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const
   for (size_t i = 1; i < input_count; i++) {
     const auto& next_input = node_unit.Inputs()[i].node_arg.Name();
     if (!qnn_model_wrapper.IsConstantInput(next_input)) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN desn't support dynamic slice.");
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN doesn't support dynamic slice.");
     }
   }
 }
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc
index e8acaf75143d8..ffa3413a84889 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc
@@ -44,7 +44,7 @@ std::vector<uint32_t> FlattenShapeFromAxis(std::vector<uint32_t>& input_shape, i
   Return the shape with all dimensions multiplied onward from the specified axis.
   If axis is 0, the returned shape will include an additional batch of size 1 as the first dimension.
   */
-  assert(axis >= 0 && axis < input_shape.size());
+  assert(axis >= 0 && static_cast<size_t>(axis) < input_shape.size());
 
   std::vector<uint32_t> output_shape(input_shape.begin(), input_shape.begin() + axis);
   if (axis == 0) {
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc
index 555992ef00bfe..cba1faaa4fa2d 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc
@@ -42,7 +42,7 @@ Status TileOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
                                     std::vector<std::string>& input_names,
                                     bool do_op_validation) const {
   const auto& inputs = node_unit.Inputs();
-  // QNN Tile only support 1 input, the 2nd input need to be initialier and set as Qnn node parameter
+  // QNN Tile only supports 1 input; the 2nd input needs to be an initializer and set as a QNN node parameter
   if (do_op_validation) {
     auto& repeats_input_name = inputs[1].node_arg.Name();
     ORT_RETURN_IF_NOT(qnn_model_wrapper.IsConstantInput(repeats_input_name),
@@ -60,7 +60,7 @@ Status TileOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
                                                   const logging::Logger& logger,
                                                   bool do_op_validation) const {
   std::vector<std::string> param_tensor_names;
-  // Already confirmed repeats input is initailizer in ProcessInputs()
+  // Already confirmed repeats input is initializer in ProcessInputs()
   const auto& repeats_input_name = node_unit.Inputs()[1].node_arg.Name();
 
   std::vector<uint8_t> unpacked_tensor;
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc
index 60e0f81aecc1f..8421bd4a99196 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc
@@ -180,14 +180,16 @@ Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) {
 
   auto result = SetupTensors(qnn_input_infos_, graph_info_->InputTensors());
 
   if (Status::OK() != result) {
-    LOGS(logger, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
-    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN input tensors!");
+    const std::string message = "Failed to setup QNN input tensors for graph: " + graph_info_->Name();
+    LOGS(logger, ERROR) << message;
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message);
   }
 
   result = SetupTensors(qnn_output_infos_, graph_info_->OutputTensors(), false);
   if (Status::OK() != result) {
-    LOGS(logger, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
-    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN output tensors!");
+    const std::string message = "Failed to setup QNN output tensors for graph: " + graph_info_->Name();
+    LOGS(logger, ERROR) << message;
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message);
   }
 
   return Status::OK();
diff --git a/onnxruntime/test/providers/qnn/einsum_op_test.cc b/onnxruntime/test/providers/qnn/einsum_op_test.cc
new file mode 100644
index 0000000000000..11606cd2b4c68
--- /dev/null
+++ b/onnxruntime/test/providers/qnn/einsum_op_test.cc
@@ -0,0 +1,190 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
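+//
+// Unit tests for QNN EP Einsum support. Each test builds a single-node Einsum model,
+// runs it through the QNN EP, and expects the node to be fully assigned to QNN.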
+
+#if !defined(ORT_MINIMAL_BUILD)
+
+#include <string>
+#include <vector>
+
+#include "test/providers/qnn/qnn_test_utils.h"
+#include "core/graph/node_attr_utils.h"
+#include "test/util/include/test_utils.h"
+
+#include "core/graph/onnx_protobuf.h"
+#include "gtest/gtest.h"
+
+namespace {
+
+using onnxruntime::ProviderOptions;
+using onnxruntime::test::BuildOpTestCase;
+using onnxruntime::test::ExpectedEPNodeAssignment;
+using onnxruntime::test::RunQnnModelTest;
+using onnxruntime::test::TestInputDef;
+using onnxruntime::utils::MakeAttribute;
+
+template <typename DataType>
+static void RunQnnEinsum(
+    const std::string& backend,
+    const TestInputDef<DataType>& in0,
+    const TestInputDef<DataType>& in1,
+    const std::string& equation,
+    const float f32_abs_err = 1e-4f) {
+  ProviderOptions provider_options;
+  provider_options["backend_type"] = backend;
+  provider_options["offload_graph_io_quantization"] = "0";
+  RunQnnModelTest(
+      /*build_test_case=*/BuildOpTestCase<DataType>(
+          /*op_type=*/"Einsum",
+          /*input_defs_1=*/{in0, in1},
+          /*input_defs_2=*/{},
+          /*attrs=*/{MakeAttribute("equation", equation)}),
+      /*provider_options=*/provider_options,
+      /*opset_version=*/13,
+      /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All,
+      /*f32_abs_err=*/f32_abs_err);
+}
+
+}  // namespace
+
+namespace onnxruntime {
+namespace test {
+
+//
+// QNN CPU
+//
+
+TEST_F(QnnCPUBackendTests, EinsumRank2) {
+  const std::vector<int64_t> shape0{2, 3};
+  const std::vector<int64_t> shape1{3, 4};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum(
+      /*backend=*/"cpu",
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"ab,bc->ac");
+}
+
+TEST_F(QnnCPUBackendTests, EinsumRank4MatMul) {
+  const std::vector<int64_t> shape0{3, 4, 5, 6};
+  const std::vector<int64_t> shape1{3, 4, 6, 5};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum(
+      /*backend=*/"cpu",
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bhij,bhjd->bhid");
+}
+
+TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeY) {
+  const std::vector<int64_t> shape0{2, 3, 4, 6};
+  const std::vector<int64_t> shape1{2, 3, 5, 6};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum(
+      /*backend=*/"cpu",
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bhid,bhjd->bhij");
+}
+
+TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll1) {
+  const std::vector<int64_t> shape0{1, 9, 1, 7};
+  const std::vector<int64_t> shape1{1, 7, 1, 9};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum(
+      /*backend=*/"cpu",
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bchq,bkhc->bkhq");
+}
+
+TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) {
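+  // "bkhq,bchk->bchq" requires transposing both inputs and the output around the
+  // MatMul (the transpose-all path).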
+  const std::vector<int64_t> shape0{1, 7, 1, 7};
+  const std::vector<int64_t> shape1{1, 9, 1, 7};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum(
+      /*backend=*/"cpu",
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bkhq,bchk->bchq");
+}
+
+//
+// QNN HTP
+//
+
+#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
+
+TEST_F(QnnHTPBackendTests, EinsumRank2MatMul) {
+  const std::vector<int64_t> shape0{2, 3};
+  const std::vector<int64_t> shape1{3, 4};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum(
+      /*backend=*/"htp",
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"ij,jk->ik",
+      /*f32_abs_err=*/1e-2f);
+}
+
+TEST_F(QnnHTPBackendTests, EinsumRank4MatMul) {
+  const std::vector<int64_t> shape0{3, 1, 5, 2};
+  const std::vector<int64_t> shape1{3, 1, 2, 5};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum(
+      /*backend=*/"htp",
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bhij,bhjd->bhid",
+      /*f32_abs_err=*/1e-2f);
+}
+
+TEST_F(QnnHTPBackendTests, EinsumRank4MatMulTransposeY) {
+  const std::vector<int64_t> shape0{2, 3, 4, 2};
+  const std::vector<int64_t> shape1{2, 3, 5, 2};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum(
+      /*backend=*/"htp",
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bhid,bhjd->bhij",
+      /*f32_abs_err=*/1e-2f);
+}
+
+TEST_F(QnnHTPBackendTests, EinsumRank4MatMulTransposeAll1) {
+  const std::vector<int64_t> shape0{1, 3, 1, 7};
+  const std::vector<int64_t> shape1{1, 7, 1, 3};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum(
+      /*backend=*/"htp",
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bchq,bkhc->bkhq",
+      /*f32_abs_err=*/1e-2f);
+}
+
+TEST_F(QnnHTPBackendTests, EinsumRank4MatMulTransposeAll2) {
+  const std::vector<int64_t> shape0{1, 4, 1, 4};
+  const std::vector<int64_t> shape1{1, 9, 1, 4};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum(
+      /*backend=*/"htp",
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bkhq,bchk->bchq",
+      /*f32_abs_err=*/1e-2f);
+}
+#endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
+
+}  // namespace test
+}  // namespace onnxruntime
+#endif  // !defined(ORT_MINIMAL_BUILD)

From 8b6c806cff445b75f2ea5e5c216527c57082418f Mon Sep 17 00:00:00 2001
From: Yuduo
Date: Thu, 1 May 2025 16:19:16 -0700
Subject: [PATCH 2/2] Address review feedback

---
 .../test/providers/qnn/einsum_op_test.cc      | 217 +++++++++++++++---
 1 file changed, 184 insertions(+), 33 deletions(-)

diff --git a/onnxruntime/test/providers/qnn/einsum_op_test.cc b/onnxruntime/test/providers/qnn/einsum_op_test.cc
index 11606cd2b4c68..55412a7b15d98 100644
--- a/onnxruntime/test/providers/qnn/einsum_op_test.cc
+++ b/onnxruntime/test/providers/qnn/einsum_op_test.cc
@@ -15,33 +15,114 @@
 
 namespace {
 
+using onnxruntime::Node;
+using onnxruntime::NodeArg;
 using onnxruntime::ProviderOptions;
+using onnxruntime::test::AddQDQNodePair;
+using onnxruntime::test::AddQDQNodePairWithOutputAsGraphOutput;
 using onnxruntime::test::BuildOpTestCase;
 using onnxruntime::test::ExpectedEPNodeAssignment;
+using onnxruntime::test::GetTestInputQuantParams;
+using onnxruntime::test::GetTestQDQModelFn;
+using onnxruntime::test::MakeTestInput;
+using onnxruntime::test::ModelTestBuilder;
+using onnxruntime::test::QDQTolerance;
+using onnxruntime::test::QuantParams;
 using onnxruntime::test::RunQnnModelTest;
 using onnxruntime::test::TestInputDef;
+using onnxruntime::test::TestQDQModelAccuracy;
 using onnxruntime::utils::MakeAttribute;
 
+constexpr char kEinsumOp[] = "Einsum";
+constexpr char kEinsumEquation[] = "equation";
+constexpr char kQnnBackendType[] = "backend_type";
+constexpr char kQnnBackendTypeCpu[] = "cpu";
+constexpr char kQnnBackendTypeHtp[] = "htp";
+constexpr char kOffloadGraphIoQuantization[] = "offload_graph_io_quantization";
+constexpr char kOffloadGraphIoQuantizationDisable[] = "0";
+
 template <typename DataType>
 static void RunQnnEinsum(
     const std::string& backend,
     const TestInputDef<DataType>& in0,
     const TestInputDef<DataType>& in1,
     const std::string& equation,
-    const float f32_abs_err = 1e-4f) {
+    const float tolerance) {
   ProviderOptions provider_options;
-  provider_options["backend_type"] = backend;
-  provider_options["offload_graph_io_quantization"] = "0";
+  provider_options[kQnnBackendType] = backend;
+  provider_options[kOffloadGraphIoQuantization] = kOffloadGraphIoQuantizationDisable;
   RunQnnModelTest(
       /*build_test_case=*/BuildOpTestCase<DataType>(
-          /*op_type=*/"Einsum",
+          /*op_type=*/kEinsumOp,
           /*input_defs_1=*/{in0, in1},
           /*input_defs_2=*/{},
-          /*attrs=*/{MakeAttribute("equation", equation)}),
+          /*attrs=*/{MakeAttribute(kEinsumEquation, equation)}),
       /*provider_options=*/provider_options,
-      /*opset_version=*/13,
+      /*opset_version=*/12,
       /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All,
-      /*f32_abs_err=*/f32_abs_err);
+      /*tolerance=*/tolerance);
+}
+
+template <typename QuantType>
+GetTestQDQModelFn<QuantType> BuildTestCaseQdq(const std::vector<TestInputDef<float>>& input_defs,
+                                              const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs,
+                                              bool use_contrib_qdq = false) {
+  return [input_defs, attrs, use_contrib_qdq](ModelTestBuilder& builder,
+                                              std::vector<QuantParams<QuantType>>& output_qparams) {
+    const size_t num_inputs = input_defs.size();
+
+    std::vector<NodeArg*> op_inputs;
+    op_inputs.reserve(num_inputs);
+
+    // Process input 0
+    NodeArg* input0 = MakeTestInput(builder, input_defs[0]);
+    QuantParams<QuantType> input0_qparams = GetTestInputQuantParams<QuantType>(input_defs[0]);
+    NodeArg* input0_after_qdq = AddQDQNodePair<QuantType>(builder, input0, input0_qparams.scale,
+                                                          input0_qparams.zero_point, use_contrib_qdq);
+    op_inputs.push_back(input0_after_qdq);
+
+    // Process input 1
+    NodeArg* input1 = MakeTestInput(builder, input_defs[1]);
+    QuantParams<QuantType> input1_qparams = GetTestInputQuantParams<QuantType>(input_defs[1]);
+    NodeArg* input1_after_qdq = AddQDQNodePair<QuantType>(builder, input1, input1_qparams.scale,
+                                                          input1_qparams.zero_point, use_contrib_qdq);
+    op_inputs.push_back(input1_after_qdq);
+
+    // Op -> op_output
+    auto* output = builder.MakeIntermediate();
+    Node& node = builder.AddNode(kEinsumOp, op_inputs, {output});
+    for (const auto& attr : attrs) {
+      node.AddAttributeProto(attr);
+    }
+
+    // op_output -> Q -> DQ -> output
+    AddQDQNodePairWithOutputAsGraphOutput<QuantType>(builder, output, output_qparams[0].scale,
+                                                     output_qparams[0].zero_point, use_contrib_qdq);
+  };
+}
+
+template <typename QuantType>
+static void RunQnnHtpQdqEinsum(const TestInputDef<float>& in0,
+                               const TestInputDef<float>& in1,
+                               const std::string& equation,
+                               QDQTolerance tolerance) {
+  ProviderOptions provider_options;
+  provider_options[kQnnBackendType] = kQnnBackendTypeHtp;
+  provider_options[kOffloadGraphIoQuantization] = kOffloadGraphIoQuantizationDisable;
+  std::vector<ONNX_NAMESPACE::AttributeProto> attrs{MakeAttribute(kEinsumEquation, equation)};
+  auto f32_model_builder = BuildOpTestCase<float>(
+      /*op_type=*/kEinsumOp,
+      /*input_defs_1=*/{in0, in1},
+      /*input_defs_2=*/{},
+      /*attrs=*/attrs);
+  auto qdq_model_builder = BuildTestCaseQdq<QuantType>(
+      /*input_defs=*/{in0, in1}, /*attrs=*/attrs, /*use_contrib_qdq=*/false);
+  TestQDQModelAccuracy(/*f32_model_fn=*/f32_model_builder,
+                       /*qdq_model_fn=*/qdq_model_builder,
+                       /*qnn_options=*/provider_options,
+                       /*opset_version=*/12,
+                       /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All,
+                       /*tolerance=*/tolerance);
+}
 
 }  // namespace
@@ -59,10 +140,11 @@ TEST_F(QnnCPUBackendTests, EinsumRank2) {
   const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
   const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
   RunQnnEinsum(
-      /*backend=*/"cpu",
+      /*backend=*/kQnnBackendTypeCpu,
       /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
       /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
-      /*equation=*/"ab,bc->ac");
+      /*equation=*/"ab,bc->ac",
+      /*tolerance=*/1e-4f);
 }
 
@@ -71,10 +153,11 @@ TEST_F(QnnCPUBackendTests, EinsumRank4MatMul) {
   const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
   const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
   RunQnnEinsum(
-      /*backend=*/"cpu",
+      /*backend=*/kQnnBackendTypeCpu,
       /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
       /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
-      /*equation=*/"bhij,bhjd->bhid");
+      /*equation=*/"bhij,bhjd->bhid",
+      /*tolerance=*/1e-4f);
 }
 
@@ -83,10 +166,11 @@ TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeY) {
   const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
   const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
   RunQnnEinsum(
-      /*backend=*/"cpu",
+      /*backend=*/kQnnBackendTypeCpu,
       /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
       /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
-      /*equation=*/"bhid,bhjd->bhij");
+      /*equation=*/"bhid,bhjd->bhij",
+      /*tolerance=*/1e-4f);
 }
 
@@ -95,10 +179,11 @@ TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll1) {
   const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
   const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
   RunQnnEinsum(
-      /*backend=*/"cpu",
+      /*backend=*/kQnnBackendTypeCpu,
       /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
       /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
-      /*equation=*/"bchq,bkhc->bkhq");
+      /*equation=*/"bchq,bkhc->bkhq",
+      /*tolerance=*/1e-4f);
 }
 
@@ -107,82 +192,148 @@ TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) {
   const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
   const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
   RunQnnEinsum(
-      /*backend=*/"cpu",
+      /*backend=*/kQnnBackendTypeCpu,
       /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
       /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
-      /*equation=*/"bkhq,bchk->bchq");
+      /*equation=*/"bkhq,bchk->bchq",
+      /*tolerance=*/1e-4f);
 }
 
 //
-// QNN HTP
+// QNN HTP F16
 //
 
 #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
 
-TEST_F(QnnHTPBackendTests, EinsumRank2MatMul) {
+TEST_F(QnnHTPBackendTests, EinsumF16Rank2MatMul) {
   const std::vector<int64_t> shape0{2, 3};
   const std::vector<int64_t> shape1{3, 4};
   const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
   const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
   RunQnnEinsum(
-      /*backend=*/"htp",
+      /*backend=*/kQnnBackendTypeHtp,
       /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
       /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
       /*equation=*/"ij,jk->ik",
-      /*f32_abs_err=*/1e-2f);
+      /*tolerance=*/1e-2f);
 }
 
-TEST_F(QnnHTPBackendTests, EinsumRank4MatMul) {
+TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMul) {
   const std::vector<int64_t> shape0{3, 1, 5, 2};
   const std::vector<int64_t> shape1{3, 1, 2, 5};
   const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
   const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
   RunQnnEinsum(
-      /*backend=*/"htp",
+      /*backend=*/kQnnBackendTypeHtp,
       /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
       /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
       /*equation=*/"bhij,bhjd->bhid",
-      /*f32_abs_err=*/1e-2f);
+      /*tolerance=*/1e-2f);
 }
 
-TEST_F(QnnHTPBackendTests, EinsumRank4MatMulTransposeY) {
+TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeY) {
   const std::vector<int64_t> shape0{2, 3, 4, 2};
   const std::vector<int64_t> shape1{2, 3, 5, 2};
   const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
   const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
   RunQnnEinsum(
-      /*backend=*/"htp",
+      /*backend=*/kQnnBackendTypeHtp,
       /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
       /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
       /*equation=*/"bhid,bhjd->bhij",
-      /*f32_abs_err=*/1e-2f);
+      /*tolerance=*/1e-2f);
 }
 
-TEST_F(QnnHTPBackendTests, EinsumRank4MatMulTransposeAll1) {
+TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeAll1) {
   const std::vector<int64_t> shape0{1, 3, 1, 7};
   const std::vector<int64_t> shape1{1, 7, 1, 3};
   const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
   const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
   RunQnnEinsum(
-      /*backend=*/"htp",
+      /*backend=*/kQnnBackendTypeHtp,
       /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
       /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
       /*equation=*/"bchq,bkhc->bkhq",
-      /*f32_abs_err=*/1e-2f);
+      /*tolerance=*/1e-2f);
 }
 
-TEST_F(QnnHTPBackendTests, EinsumRank4MatMulTransposeAll2) {
+TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeAll2) {
   const std::vector<int64_t> shape0{1, 4, 1, 4};
   const std::vector<int64_t> shape1{1, 9, 1, 4};
   const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
   const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
   RunQnnEinsum(
-      /*backend=*/"htp",
+      /*backend=*/kQnnBackendTypeHtp,
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bkhq,bchk->bchq",
+      /*tolerance=*/1e-2f);
+}
+
+//
+// QNN HTP QDQ
+//
+
+TEST_F(QnnHTPBackendTests, EinsumQdqRank2MatMul) {
+  const std::vector<int64_t> shape0{2, 3};
+  const std::vector<int64_t> shape1{3, 4};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnHtpQdqEinsum<uint8_t>(
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"ij,jk->ik",
+      /*tolerance=*/QDQTolerance());
+}
+
+TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMul) {
+  const std::vector<int64_t> shape0{3, 1, 5, 2};
+  const std::vector<int64_t> shape1{3, 1, 2, 5};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnHtpQdqEinsum<uint8_t>(
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bhij,bhjd->bhid",
+      /*tolerance=*/QDQTolerance());
+}
+
+TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeY) {
+  const std::vector<int64_t> shape0{2, 3, 4, 2};
+  const std::vector<int64_t> shape1{2, 3, 5, 2};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnHtpQdqEinsum<uint8_t>(
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bhid,bhjd->bhij",
+      /*tolerance=*/QDQTolerance());
+}
+
+TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeAll1) {
+  const std::vector<int64_t> shape0{1, 3, 1, 7};
+  const std::vector<int64_t> shape1{1, 7, 1, 3};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnHtpQdqEinsum<uint8_t>(
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bchq,bkhc->bkhq",
+      /*tolerance=*/QDQTolerance());
+}
+
+TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeAll2) {
+  const std::vector<int64_t> shape0{1, 4, 1, 4};
+  const std::vector<int64_t> shape1{1, 9, 1, 4};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnHtpQdqEinsum<uint8_t>(
       /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
       /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
       /*equation=*/"bkhq,bchk->bchq",
-      /*f32_abs_err=*/1e-2f);
+      /*tolerance=*/QDQTolerance());
 }
+
 #endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
 
 }  // namespace test