Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions js/web/docs/webnn-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| Flatten | ai.onnx(7-8, 9-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | |
| Floor | ai.onnx(7-12, 13+) | floor | ✓ | ✓ | |
| Gather | ai.onnx(7-10, 11-12, 13+) | gather | ✓ | ✓ | |
| GatherElements | ai.onnx(11-12, 13+) | gatherElements | ✗ | ✓ | |
| GatherND | ai.onnx(11, 12, 13+) | gatherND | ✓ | ✓ | Only supports 'batch_dims' == 0 |
| Gelu | ai.onnx(20+) | gelu | ✓ | ✓ | |
| Gemm | ai.onnx(7-8, 9-10, 11-12, 13+) | gemm | ✓ | ✓ | Only supports 1-D 'C' input |
Expand Down Expand Up @@ -80,6 +81,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| Relu | ai.onnx(7-12, 13, 14+) | relu | ✓ | ✓ | |
| Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape | ✓ | ✓ | Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported |
| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, antialias == 0, coordinate_transformation_mode == 'half_pixel', exclude_outside == 0, keep_aspect_ratio_policy == 'stretch', 'linear' and 'nearest' modes, input 'scales' and 'sizes' if present must be a constant |
| ScatterElements | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterElements | ✗ | ✓ | Only supports 'reduction' == 'none' |
| ScatterND | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterND | ✗ | ✓ | Only supports 'reduction' == 'none' |
| Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice | ✓ | ✓ | |
| Sigmoid | ai.onnx(7-12, 13+) | sigmoid | ✓ | ✓ | |
Expand Down
8 changes: 4 additions & 4 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -2254,10 +2254,10 @@
// // "test_round",
// // "test_scan_sum",
// // "test_scan9_sum",
// // "test_scatter_elements_with_axis",
// // "test_scatter_elements_with_duplicate_indices",
// // "test_scatter_elements_with_negative_indices",
// // "test_scatter_elements_without_axis",
"test_scatter_elements_with_axis",
"test_scatter_elements_with_duplicate_indices",
"test_scatter_elements_with_negative_indices",
"test_scatter_elements_without_axis",
// // "test_scatter_with_axis",
// // "test_scatter_without_axis",
"test_scatternd_add",
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Flatten", "reshape"},
{"Floor", "floor"},
{"Gather", "gather"},
{"GatherElements", "gatherElements"},
{"GatherND", "gatherND"},
{"Gelu", "gelu"},
{"Gemm", "gemm"},
Expand Down Expand Up @@ -261,6 +262,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Relu", "relu"},
{"Reshape", "reshape"},
{"Resize", "resample2d"},
{"ScatterElements", "scatterElements"},
{"ScatterND", "scatterND"},
{"Shape", "slice"},
{"Sigmoid", "sigmoid"},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Intel Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/webnn/builders/helper.h"
#include "core/providers/webnn/builders/model_builder.h"
#include "core/providers/webnn/builders/op_builder_factory.h"

#include "base_op_builder.h"

Check warning on line 11 in onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc:11: Include the directory when naming header files [build/include_subdir] [4]

namespace onnxruntime {
namespace webnn {

class GatherElementsOpBuilder : public BaseOpBuilder {
// Add operator related.
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related.
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};

// Add operator related.

Status GatherElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const {

Check warning on line 30 in onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc:30: Do not indent within a namespace. [whitespace/indent_namespace] [4]
const auto& input_defs = node.InputDefs();
emscripten::val data = model_builder.GetOperand(input_defs[0]->Name());
emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name());
emscripten::val options = emscripten::val::object();
options.set("label", node.Name());

std::vector<int64_t> input_shape;

Check warning on line 37 in onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc:37: Add #include <vector> for vector<> [build/include_what_you_use] [4]
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
const size_t rank = input_shape.size();
NodeAttrHelper helper(node);
const uint32_t axis = static_cast<uint32_t>(HandleNegativeAxis(helper.Get("axis", 0), rank));
options.set("axis", axis);

emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("gatherElements", data, indices, options);

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));

Check warning on line 46 in onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc:46: Add #include <utility> for move [build/include_what_you_use] [4]
return Status::OK();
}

// Operator support related.

bool GatherElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {

Check warning on line 53 in onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc:53: Do not indent within a namespace. [whitespace/indent_namespace] [4]
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& op_type = node.OpType();

int32_t data_type;
int32_t indices_type;
if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) {
return false;
}

return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
}

void CreateGatherElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {

Check warning on line 68 in onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc:68: Add #include <string> for string [build/include_what_you_use] [4]
op_registrations.builders.push_back(std::make_unique<GatherElementsOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
}

} // namespace webnn
} // namespace onnxruntime
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Intel Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/webnn/builders/helper.h"
#include "core/providers/webnn/builders/model_builder.h"
#include "core/providers/webnn/builders/op_builder_factory.h"

#include "base_op_builder.h"

Check warning on line 11 in onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc:11: Include the directory when naming header files [build/include_subdir] [4]

namespace onnxruntime {
namespace webnn {

class ScatterElementsOpBuilder : public BaseOpBuilder {
// Add operator related.
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};

// Add operator related.

Status ScatterElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const {

Check warning on line 32 in onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc:32: Do not indent within a namespace. [whitespace/indent_namespace] [4]
const auto& input_defs = node.InputDefs();
emscripten::val data = model_builder.GetOperand(input_defs[0]->Name());
emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name());
emscripten::val updates = model_builder.GetOperand(input_defs[2]->Name());
emscripten::val options = emscripten::val::object();
options.set("label", node.Name());

std::vector<int64_t> input_shape;
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
const size_t rank = input_shape.size();
NodeAttrHelper helper(node);
const uint32_t axis = static_cast<uint32_t>(HandleNegativeAxis(helper.Get("axis", 0), rank));
options.set("axis", axis);

emscripten::val output =
model_builder.GetBuilder().call<emscripten::val>("scatterElements", data, indices, updates, options);

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
return Status::OK();
}

// Operator support related.

bool ScatterElementsOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */,

Check warning on line 57 in onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc:57: Do not indent within a namespace. [whitespace/indent_namespace] [4]
const logging::Logger& logger) const {

Check warning on line 58 in onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc:58: Do not indent within a namespace. [whitespace/indent_namespace] [4]
NodeAttrHelper helper(node);
if (helper.Get("reduction", "none") != "none") {
LOGS(logger, VERBOSE) << "ScatterElements: WebNN only supports reduction type none (default)";
return false;
}

return true;
}

bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& updates = *node.InputDefs()[2];
const auto& op_type = node.OpType();

int32_t data_type;
int32_t indices_type;
int32_t updates_type;
if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger) ||
!GetType(updates, updates_type, logger)) {
return false;
}

if (data_type != updates_type) {
return false;
}

return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
}

void CreateScatterElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<ScatterElementsOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
}

} // namespace webnn
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateGatherOpBuilder("Gather", op_registrations);
}

{ // GatherElements
CreateGatherElementsOpBuilder("GatherElements", op_registrations);
}

{ // GatherND
CreateGatherNDOpBuilder("GatherND", op_registrations);
}
Expand Down Expand Up @@ -174,6 +178,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateResizeOpBuilder("Resize", op_registrations);
}

{ // ScatterElements
CreateScatterElementsOpBuilder("ScatterElements", op_registrations);
}

{ // ScatterND
CreateScatterNDOpBuilder("ScatterND", op_registrations);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderR
void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGatherElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
Expand All @@ -44,6 +45,7 @@ void CreateQDQOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r
void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateScatterElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateShapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
Expand Down
Loading