Skip to content

Commit dcf9126

Browse files
authored
[WebNN EP] Support GatherND and ScatterND op (#22181)
1 parent 975d3df commit dcf9126

File tree

7 files changed

+189
-6
lines changed

7 files changed

+189
-6
lines changed

js/web/docs/webnn-operators.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
3535
| Flatten | ai.onnx(7-8, 9-10, 11-12, 13-20, 21+) | reshape ||| |
3636
| Floor | ai.onnx(7-12, 13+) | floor ||| |
3737
| Gather | ai.onnx(7-10, 11-12, 13+) | gather ||| |
38+
| GatherND | ai.onnx(11, 12, 13+) | gatherND ||| Only supports 'batch_dims' == 0 |
3839
| Gelu | ai.onnx(20+) | gelu ||| |
3940
| Gemm | ai.onnx(7-8, 9-10, 11-12, 13+) | gemm ||| Only supports 1-D 'C' input |
4041
| GlobalAveragePool | ai.onnx(7+) | averagePool2d ||| Only supports 4-D input |
@@ -79,6 +80,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
7980
| Relu | ai.onnx(7-12, 13, 14+) | relu ||| |
8081
| Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape ||| Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported |
8182
| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d ||| Only supports 4-D input, antialias == 0, coordinate_transformation_mode == 'half_pixel', exclude_outside == 0, keep_aspect_ratio_policy == 'stretch', 'linear' and 'nearest' modes, input 'scales' and 'sizes' if present must be a constant |
83+
| ScatterND | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterND ||| Only supports 'reduction' == 'none' |
8284
| Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice ||| |
8385
| Sigmoid | ai.onnx(7-12, 13+) | sigmoid ||| |
8486
| Softplus | ai.onnx(7+) | softplus ||| |

js/web/test/suite-test-list.jsonc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1777,9 +1777,9 @@
17771777
"test_gather_elements_0",
17781778
"test_gather_elements_1",
17791779
"test_gather_elements_negative_indices",
1780-
// "test_gathernd_example_float32",
1781-
// "test_gathernd_example_int32_batch_dim1",
1782-
// "test_gathernd_example_int32",
1780+
"test_gathernd_example_float32",
1781+
"test_gathernd_example_int32_batch_dim1",
1782+
"test_gathernd_example_int32",
17831783
"test_gemm_all_attributes",
17841784
"test_gemm_alpha",
17851785
"test_gemm_beta",
@@ -2260,9 +2260,9 @@
22602260
// // "test_scatter_elements_without_axis",
22612261
// // "test_scatter_with_axis",
22622262
// // "test_scatter_without_axis",
2263-
// // "test_scatternd_add",
2264-
// // "test_scatternd_multiply",
2265-
// // "test_scatternd",
2263+
"test_scatternd_add",
2264+
"test_scatternd_multiply",
2265+
"test_scatternd",
22662266
// // "test_sce_mean_3d_expanded",
22672267
// // "test_sce_mean_3d_log_prob_expanded",
22682268
// // "test_sce_mean_3d_log_prob",

onnxruntime/core/providers/webnn/builders/helper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
215215
{"Flatten", "reshape"},
216216
{"Floor", "floor"},
217217
{"Gather", "gather"},
218+
{"GatherND", "gatherND"},
218219
{"Gelu", "gelu"},
219220
{"Gemm", "gemm"},
220221
{"GlobalAveragePool", "averagePool2d"},
@@ -260,6 +261,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
260261
{"Relu", "relu"},
261262
{"Reshape", "reshape"},
262263
{"Resize", "resample2d"},
264+
{"ScatterND", "scatterND"},
263265
{"Shape", "slice"},
264266
{"Sigmoid", "sigmoid"},
265267
{"Softplus", "softplus"},
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Copyright (c) Intel Corporation. All rights reserved.
3+
// Licensed under the MIT License.
4+
5+
#include "core/providers/common.h"
6+
#include "core/providers/shared/utils/utils.h"
7+
#include "core/providers/webnn/builders/helper.h"
8+
#include "core/providers/webnn/builders/model_builder.h"
9+
#include "core/providers/webnn/builders/op_builder_factory.h"
10+
11+
#include "base_op_builder.h"
12+
13+
namespace onnxruntime {
14+
namespace webnn {
15+
16+
class GatherNDOpBuilder : public BaseOpBuilder {
17+
// Add operator related.
18+
private:
19+
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
20+
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
21+
22+
// Operator support related.
23+
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
24+
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
25+
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
26+
const logging::Logger& logger) const override;
27+
};
28+
29+
// Add operator related.
30+
31+
Status GatherNDOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
32+
const logging::Logger& logger) const {
33+
const auto& input_defs = node.InputDefs();
34+
emscripten::val data = model_builder.GetOperand(input_defs[0]->Name());
35+
emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name());
36+
emscripten::val options = emscripten::val::object();
37+
options.set("label", node.Name());
38+
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("gatherND", data, indices, options);
39+
40+
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
41+
return Status::OK();
42+
}
43+
44+
// Operator support related.
45+
46+
bool GatherNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
47+
const WebnnDeviceType /* device_type */,
48+
const logging::Logger& logger) const {
49+
NodeAttrHelper helper(node);
50+
if (helper.Get("batch_dims", 0) != 0) {
51+
LOGS(logger, VERBOSE) << "GatherND: WebNN only supports batch_dims 0 (default)";
52+
return false;
53+
}
54+
55+
return true;
56+
}
57+
58+
bool GatherNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
59+
const logging::Logger& logger) const {
60+
const auto& data = *node.InputDefs()[0];
61+
const auto& indices = *node.InputDefs()[1];
62+
const auto& op_type = node.OpType();
63+
64+
int32_t data_type;
65+
int32_t indices_type;
66+
if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) {
67+
return false;
68+
}
69+
70+
return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
71+
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
72+
}
73+
74+
void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
75+
op_registrations.builders.push_back(std::make_unique<GatherNDOpBuilder>());
76+
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
77+
}
78+
79+
} // namespace webnn
80+
} // namespace onnxruntime
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Copyright (c) Intel Corporation. All rights reserved.
3+
// Licensed under the MIT License.
4+
5+
#include "core/providers/common.h"
6+
#include "core/providers/shared/utils/utils.h"
7+
#include "core/providers/webnn/builders/helper.h"
8+
#include "core/providers/webnn/builders/model_builder.h"
9+
#include "core/providers/webnn/builders/op_builder_factory.h"
10+
11+
#include "base_op_builder.h"
12+
13+
namespace onnxruntime {
14+
namespace webnn {
15+
16+
class ScatterNDOpBuilder : public BaseOpBuilder {
17+
// Add operator related.
18+
private:
19+
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
20+
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
21+
22+
// Operator support related.
23+
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
24+
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
25+
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
26+
const logging::Logger& logger) const override;
27+
};
28+
29+
// Add operator related.
30+
31+
Status ScatterNDOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
32+
const logging::Logger& logger) const {
33+
const auto& input_defs = node.InputDefs();
34+
emscripten::val data = model_builder.GetOperand(input_defs[0]->Name());
35+
emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name());
36+
emscripten::val updates = model_builder.GetOperand(input_defs[2]->Name());
37+
emscripten::val options = emscripten::val::object();
38+
options.set("label", node.Name());
39+
emscripten::val output =
40+
model_builder.GetBuilder().call<emscripten::val>("scatterND", data, indices, updates, options);
41+
42+
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
43+
return Status::OK();
44+
}
45+
46+
// Operator support related.
47+
48+
bool ScatterNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
49+
const WebnnDeviceType /* device_type */,
50+
const logging::Logger& logger) const {
51+
NodeAttrHelper helper(node);
52+
if (helper.Get("reduction", "none") != "none") {
53+
LOGS(logger, VERBOSE) << "ScatterND: WebNN only supports reduction type none (default)";
54+
return false;
55+
}
56+
57+
return true;
58+
}
59+
60+
bool ScatterNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
61+
const logging::Logger& logger) const {
62+
const auto& data = *node.InputDefs()[0];
63+
const auto& indices = *node.InputDefs()[1];
64+
const auto& updates = *node.InputDefs()[2];
65+
const auto& op_type = node.OpType();
66+
67+
int32_t data_type;
68+
int32_t indices_type;
69+
int32_t updates_type;
70+
if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger) ||
71+
!GetType(updates, updates_type, logger)) {
72+
return false;
73+
}
74+
75+
if (data_type != updates_type) {
76+
return false;
77+
}
78+
79+
return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
80+
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
81+
}
82+
83+
void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
84+
op_registrations.builders.push_back(std::make_unique<ScatterNDOpBuilder>());
85+
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
86+
}
87+
88+
} // namespace webnn
89+
} // namespace onnxruntime

onnxruntime/core/providers/webnn/builders/op_builder_factory.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
9898
CreateGatherOpBuilder("Gather", op_registrations);
9999
}
100100

101+
{ // GatherND
102+
CreateGatherNDOpBuilder("GatherND", op_registrations);
103+
}
104+
101105
{ // Flatten
102106
CreateFlattenOpBuilder("Flatten", op_registrations);
103107
}
@@ -170,6 +174,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
170174
CreateResizeOpBuilder("Resize", op_registrations);
171175
}
172176

177+
{ // ScatterND
178+
CreateScatterNDOpBuilder("ScatterND", op_registrations);
179+
}
180+
173181
{ // Shape
174182
CreateShapeOpBuilder("Shape", op_registrations);
175183
}

onnxruntime/core/providers/webnn/builders/op_builder_factory.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderR
3131
void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3232
void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3333
void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
34+
void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3435
void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3536
void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3637
void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
@@ -43,6 +44,7 @@ void CreateQDQOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r
4344
void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
4445
void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
4546
void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
47+
void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
4648
void CreateShapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
4749
void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
4850
void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

0 commit comments

Comments
 (0)