25 changes: 11 additions & 14 deletions onnxruntime/core/providers/webnn/builders/helper.cc
@@ -69,28 +69,30 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We
   }
 }
 
-bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) {
-  const auto& node_arg_name = node_arg.Name();
-  const auto* shape_proto = node_arg.Shape();
+bool IsInputSupported(const NodeArg& input, const std::string& parent_name, const logging::Logger& logger) {
+  const auto& input_name = input.Name();
+  const auto* shape_proto = input.Shape();
   // Optional tensors can be indicated by an empty name, just ignore it.
-  if (node_arg_name.empty()) {
+  if (input_name.empty()) {
     return true;
   }
-  // We do not support input/output with no shape.
+  // We do not support an input with no shape.
   if (!shape_proto) {
-    LOGS(logger, VERBOSE) << "Node arg [" << node_arg_name << "] of [" << parent_name << "] has not shape";
+    LOGS(logger, VERBOSE) << "Input [" << input_name << "] of [" << parent_name
+                          << "] has no shape";
     return false;
   }
 
   for (const auto& dim : shape_proto->dim()) {
     // WebNN doesn't support dynamic shape - use sessionOptions.freeDimensionOverrides to fix the shape.
     if (!dim.has_dim_value()) {
       LOGS(logger, VERBOSE) << "Dynamic shape is not supported, "
-                            << "use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name;
+                            << "use sessionOptions.freeDimensionOverrides to set a fixed shape for input: "
+                            << input_name;
       return false;
     }
     if (dim.dim_value() == 0) {
-      LOGS(logger, VERBOSE) << "The shape of [" << node_arg_name << "] has 0 dimension which is not supported by WebNN";
+      LOGS(logger, VERBOSE) << "The shape of [" << input_name << "] has a 0 dimension, which is not supported by WebNN";
      return false;
    }
  }
@@ -106,12 +108,7 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
   std::vector<std::vector<size_t>> supported_node_groups;
 
   for (const auto* input : graph_viewer.GetInputs()) {
-    if (!IsTensorShapeSupported(*input, "graph", logger)) {
-      return supported_node_groups;
-    }
-  }
-  for (const auto* output : graph_viewer.GetOutputs()) {
-    if (!IsTensorShapeSupported(*output, "graph", logger)) {
+    if (!IsInputSupported(*input, "graph", logger)) {
       return supported_node_groups;
     }
   }
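
Note: the VERBOSE messages above point users at the session option that pins dynamic (dim_param) dimensions before the WebNN EP inspects the graph. Below is a minimal sketch of that workaround from the onnxruntime-web side, not part of this PR; the model path and dimension names (model.onnx, batch_size, sequence_length) are placeholders, and it assumes the WebNN execution provider is available in the build being used.

import * as ort from 'onnxruntime-web';

async function createSession(): Promise<ort.InferenceSession> {
  // freeDimensionOverrides pins each symbolic (dim_param) dimension to a fixed
  // value, so every graph input has a static shape by the time the WebNN EP
  // runs IsInputSupported(). Dimension names below are placeholders.
  return ort.InferenceSession.create('model.onnx', {
    executionProviders: ['webnn'],
    freeDimensionOverrides: {
      batch_size: 1,
      sequence_length: 128,
    },
  });
}

With every free dimension overridden, inputs reach IsInputSupported() with concrete dim_value entries, which is also what the assert in ModelBuilder::RegisterModelInputOutput further down relies on.
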
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webnn/builders/helper.h
@@ -180,7 +180,7 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s
   return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; });
 }
 
-bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
+bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
 
 // Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
 std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
16 changes: 2 additions & 14 deletions onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc
@@ -34,7 +34,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
   if (!HasSupportedInputs(node, wnn_limits, logger))
     return false;
 
-  if (!HasSupportedOutputs(node, wnn_limits, logger))
+  if (!HasSupportedOutputsImpl(node, wnn_limits, logger))
     return false;
 
   if (!HasSupportedOpSet(node, logger))
@@ -47,7 +47,7 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val&
                                        const logging::Logger& logger) const {
   const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
   for (const auto* input : node.InputDefs()) {
-    if (!IsTensorShapeSupported(*input, node_name, logger)) {
+    if (!IsInputSupported(*input, node_name, logger)) {
      return false;
    }
  }
@@ -68,18 +68,6 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
   return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger);
 }
 
-bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits,
-                                        const logging::Logger& logger) const {
-  const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
-  for (const auto* output : node.OutputDefs()) {
-    if (!IsTensorShapeSupported(*output, node_name, logger)) {
-      return false;
-    }
-  }
-
-  return HasSupportedOutputsImpl(node, wnn_limits, logger);
-}
-
 bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node,
                                             const emscripten::val& wnn_limits,
                                             const logging::Logger& logger) const {
1 change: 0 additions & 1 deletion onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h
@@ -54,7 +54,6 @@ class BaseOpBuilder : public IOpBuilder {
  private:
   bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
   bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
-  bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
 };
 
 }  // namespace webnn
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -222,7 +222,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
   if (!shape.empty()) {
     dims.reserve(shape.size());
     for (const auto& dim : shape) {
-      // dim_param free dimensions should have already been excluded by IsTensorShapeSupported().
+      // dim_param free dimensions should have already been excluded by IsInputSupported().
       assert(dim.has_dim_value());
       dims.push_back(SafeInt<int32_t>(dim.dim_value()));
     }