Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 75 additions & 51 deletions onnxruntime/core/providers/webnn/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,69 +99,93 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n
return true;
}

// Check if all input tensor ranks of the given node are supported by WebNN.
bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) {
const std::string_view op_type = node.OpType();
const auto it = op_inputs_map.find(op_type);
if (it == op_inputs_map.end()) {
LOGS(logger, VERBOSE) << "Operator type: [" << op_type << "] is not found in the op inputs map.";
// Check if a single input's rank of an ONNX op is supported by corresponding WebNN op.
bool IsInputRankSupported(const emscripten::val& wnn_limits,
const std::string_view webnn_op_type,
const std::string_view input_name,
const size_t input_rank,
const std::string_view node_name,
const logging::Logger& logger) {
const std::string webnn_op_type_str(webnn_op_type);
const std::string input_name_str(input_name);

if (wnn_limits[webnn_op_type_str].isUndefined()) {
LOGS(logger, VERBOSE) << "WebNN op type: [" << webnn_op_type
<< "] is not defined in WebNN MLOpSupportLimits.";
return false;
}

const auto& input_defs = node.InputDefs();
const std::string_view webnn_op_type = it->second.opType;
const std::string webnn_op_type_str(webnn_op_type);
const emscripten::val input_limits = wnn_limits[webnn_op_type_str][input_name_str];

for (const auto& input : it->second.inputs) {
if (static_cast<size_t>(input.index) >= input_defs.size() || input_defs[input.index] == nullptr) {
LOGS(logger, VERBOSE) << "Input index [" << input.index
<< "] for operator type [" << op_type
<< "], corresponding WebNN op type [" << webnn_op_type
<< "], WebNN input name [" << input.name
<< "] is invalid.";
return false;
}
if (input_limits.isUndefined()) {
LOGS(logger, VERBOSE) << "Node name: [" << node_name
<< "], WebNN op type: [" << webnn_op_type
<< "], input [" << input_name
<< "]: limits are not defined in WebNN MLOpSupportLimits.";
return false;
}

std::vector<int64_t> input_shape;
if (!GetShape(*input_defs[input.index], input_shape, logger)) {
return false;
}
const emscripten::val rank_range = input_limits["rankRange"];
if (rank_range.isUndefined()) {
LOGS(logger, VERBOSE) << "WebNN op type [" << webnn_op_type
<< "] input [" << input_name
<< "]: missing 'rankRange' attribute.";
return false;
}

const std::string input_name_str(input.name);
if (wnn_limits[webnn_op_type_str].isUndefined() ||
wnn_limits[webnn_op_type_str][input_name_str].isUndefined()) {
LOGS(logger, VERBOSE) << "Operator type: [" << op_type
<< "], input index: [" << input.index
<< "], corresponding WebNN op type: " << webnn_op_type
<< ", WebNN input name " << input.name
<< " is not defined in wnn_limits.";
return false;
}
const emscripten::val min_val = rank_range["min"];
const emscripten::val max_val = rank_range["max"];
if (min_val.isUndefined() || max_val.isUndefined()) {
LOGS(logger, VERBOSE) << "WebNN op type [" << webnn_op_type
<< "] input [" << input_name
<< "]: its 'rankRange' limits is missing valid 'min' or 'max' attributes.";
return false;
}

const auto& input_limits = wnn_limits[webnn_op_type_str][input_name_str];
if (input_limits["rankRange"].isUndefined()) {
LOGS(logger, VERBOSE) << "Operator type: [" << op_type
<< "], input index: [" << input.index
<< "], corresponding WebNN op type: " << webnn_op_type
<< ", WebNN input name " << input.name
<< "'s rankRange is not defined.";
return false;
size_t min_rank = min_val.as<size_t>();
size_t max_rank = max_val.as<size_t>();
if (input_rank < min_rank || input_rank > max_rank) {
LOGS(logger, VERBOSE) << "Node name: [" << node_name
<< "] WebNN op type [" << webnn_op_type
<< "] input [" << input_name << "] rank " << input_rank
<< " is not in supported range [" << min_rank << ", " << max_rank << "]";
return false;
}

return true;
}

bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) {
const std::string_view onnx_op_type = node.OpType();
const std::string_view webnn_op_type = GetWebNNOpType(onnx_op_type);

if (webnn_op_type.empty()) {
LOGS(logger, VERBOSE) << "ONNX op type: [" << onnx_op_type << "]'s corresponding WebNN op is not found.";
return false;
}

std::vector<InputInfo> inputs;
if (!GetWebNNOpInputs(onnx_op_type, inputs, logger)) {
return false;
}

const auto& input_defs = node.InputDefs();

for (const auto& input : inputs) {
// If it is an optional input and is absent, skip.
if (!TensorExists(input_defs, input.index)) {
continue;
}

int input_dim_size = static_cast<int>(input_shape.size());
int min_rank = input_limits["rankRange"]["min"].as<int>();
int max_rank = input_limits["rankRange"]["max"].as<int>();

if (input_dim_size < min_rank || input_dim_size > max_rank) {
LOGS(logger, VERBOSE) << "Operator type: [" << op_type
<< "], input index: [" << input.index
<< "], corresponding WebNN op type: " << webnn_op_type
<< ", WebNN input name: " << input.name
<< ", input size " << input_dim_size
<< " is not in supported range [" << min_rank << ", " << max_rank << "]";
std::vector<int64_t> shape;
if (!GetShape(*input_defs[input.index], shape, logger) ||
!IsInputRankSupported(wnn_limits, webnn_op_type, input.name,
shape.size(),
node.Name(), logger)) {
return false;
}
}

return true;
}

Expand Down
34 changes: 34 additions & 0 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,13 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n

bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger);

bool IsInputRankSupported(const emscripten::val& wnn_limits,
const std::string_view webnn_op_type,
const std::string_view input_name,
const size_t input_rank,
const std::string_view node_name,
const logging::Logger& logger);

// Get a set of nodes supported by WebNN EP.
std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewer,
const emscripten::val& wnn_builder,
Expand Down Expand Up @@ -244,6 +251,33 @@ inline std::string_view GetWebNNOpType(const std::string_view onnx_op_type) {
return (it != op_inputs_map.end()) ? it->second.opType : "";
}

// Get the corresponding WebNN input name by ONNX op type and input index from op_inputs_map
inline std::string_view GetWebNNInputName(const std::string_view onnx_op_type, const int input_index) {
const auto it = op_inputs_map.find(onnx_op_type);

if (it != op_inputs_map.end()) {
for (const auto& input : it->second.inputs) {
if (input.index == input_index) {
return input.name;
}
}
}

return "";
}

inline bool GetWebNNOpInputs(const std::string_view onnx_op_type,
std::vector<InputInfo>& inputs,
const logging::Logger& logger) {
const auto it = op_inputs_map.find(onnx_op_type);
if (it == op_inputs_map.end()) {
LOGS(logger, VERBOSE) << "WebNN op inputs not found for op type: " << onnx_op_type;
return false;
}
inputs = it->second.inputs;
return true;
}

bool AreDataTypesSame(const std::string_view op_type,
gsl::span<const int32_t> input_types,
const logging::Logger& logger);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ bool ConcatOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod
}
}

return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger);
return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger);
}

void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,12 @@ bool GatherElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const N
const auto& indices = *node.InputDefs()[1];
const std::string_view op_type = node.OpType();

int32_t data_type;
int32_t indices_type;
int32_t data_type, indices_type;
if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) {
return false;
}

return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,12 @@ bool GatherNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& n
const auto& indices = *node.InputDefs()[1];
const std::string_view op_type = node.OpType();

int32_t data_type;
int32_t indices_type;
int32_t data_type, indices_type;
if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) {
return false;
}

return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,13 @@ bool GatherOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod
const auto& input = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const std::string_view op_type = node.OpType();
int32_t input_type;
int32_t indices_type;
int32_t input_type, indices_type;

if (!GetType(input, input_type, logger) ||
!GetType(indices, indices_type, logger))
return false;

return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) &&
return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) &&
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
}

Expand Down
44 changes: 39 additions & 5 deletions onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
a_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name());
std::vector<int64_t> a_zero_point_shape;
ORT_RETURN_IF_NOT(GetShape(*input_defs[2], a_zero_point_shape, logger), "Cannot get shape of a_zero_point");
// Scale is not used by MatMulInteger but required by DequantizeLinear. So set it to deafult value 1.0f.
// Scale is not used by MatMulInteger but required by DequantizeLinear. So set it to default value 1.0f.
// The scale input should have the same shape as the zero point input.
a_scale = model_builder.CreateOrGetConstant<float>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
1.0f,
Expand Down Expand Up @@ -268,11 +268,45 @@ bool GemmOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node,
return false;
}

if (op_type == "MatMulInteger") {
// The first decomposed op of MatMulInteger is DequantizeLinear, and so
// we only need to ensure it supports the input0_type.
if (op_type == "Gemm") {
return IsInputRankSupportedByOp(node, wnn_limits, logger) &&
IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger);
} else if (op_type == "MatMulInteger") {
// Check up to 4 inputs for MatMulInteger
for (size_t i = 0; i < input_defs.size(); ++i) {
std::vector<int64_t> shape;
if (!GetShape(*input_defs[i], shape, logger)) {
return false;
}

// We have a workaround to support 1D inputs A and B; skip further rank checks if they are 1D
if (i <= 1 && shape.size() == 1) {
continue;
}

// For DequantizeLinear, input indices: 0 (x), 1 (scale), 2 (zero_point)
if (!IsInputRankSupported(wnn_limits, "dequantizeLinear",
(i < 2) ? "input" : "zeroPoint",
shape.size(), node.Name(), logger)) {
return false;
}
}
return IsDataTypeSupportedByOp("DequantizeLinear", input0_type, wnn_limits, "input", "x", logger);
} else {
} else { // MatMul
for (int i = 0; i < 2; ++i) {
std::vector<int64_t> shape;
if (!GetShape(*input_defs[i], shape, logger)) {
return false;
}

if (shape.size() == 1) {
continue;
}

if (!IsInputRankSupported(wnn_limits, "matmul", (i == 0) ? "a" : "b", shape.size(), node.Name(), logger)) {
return false;
}
}
return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ bool GruOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node,
return false;
}

return IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger);
return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger);
}

bool GruOpBuilder::HasSupportedOutputsImpl(const Node& node,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ bool LogicalOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& no
}

std::string onnx_input_name = op_type == "Not" ? "X" : "A";
return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger);
return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger);
}

void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ bool LstmOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node,
return false;
}

return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger);
return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger);
}

bool LstmOpBuilder::HasSupportedOutputsImpl(const Node& node,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ bool MaxMinOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod
}
}

return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger);
return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger);
}

void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ bool QDQOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node,
return false;
}

return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) &&
return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) &&
IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "scale", "x_scale", logger) &&
(!has_input2 || IsDataTypeSupportedByOp(op_type, input2_type, wnn_limits, "zeroPoint", "x_zero_point", logger));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& updates = *node.InputDefs()[2];
const std::string_view op_type = node.OpType();

int32_t data_type;
int32_t indices_type;
Expand All @@ -85,7 +84,9 @@ bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const
return false;
}

return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
const std::string_view op_type = node.OpType();

return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ bool ScatterNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node&
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& updates = *node.InputDefs()[2];
const std::string_view op_type = node.OpType();

int32_t data_type;
int32_t indices_type;
Expand All @@ -76,8 +75,8 @@ bool ScatterNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node&
if (data_type != updates_type) {
return false;
}

return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
const std::string_view op_type = node.OpType();
return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
}

Expand Down
Loading
Loading