Skip to content

Commit 39d1f6f

Browse files
committed
Add gatherElements, scatterElements
1 parent 71625a0 commit 39d1f6f

File tree

7 files changed

+189
-4
lines changed

7 files changed

+189
-4
lines changed

js/web/docs/webnn-operators.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
3535
| Flatten | ai.onnx(7-8, 9-10, 11-12, 13-20, 21+) | reshape ||| |
3636
| Floor | ai.onnx(7-12, 13+) | floor ||| |
3737
| Gather | ai.onnx(7-10, 11-12, 13+) | gather ||| |
38+
| GatherElements | ai.onnx(11-12, 13+) | gatherElements ||| |
3839
| GatherND | ai.onnx(11, 12, 13+) | gatherND ||| Only supports 'batch_dims' == 0 |
3940
| Gelu | ai.onnx(20+) | gelu ||| |
4041
| Gemm | ai.onnx(7-8, 9-10, 11-12, 13+) | gemm ||| Only supports 1-D 'C' input |
@@ -80,6 +81,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
8081
| Relu | ai.onnx(7-12, 13, 14+) | relu ||| |
8182
| Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape ||| Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported |
8283
| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d ||| Only supports 4-D input, exclude_outside != 0, input 'scales' and 'sizes' if present must be a constant, 'linear' and 'nearest' modes |
84+
| ScatterElements | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterElements ||| Only supports 'reduction' == 'none' |
8385
| ScatterND | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterND ||| Only supports 'reduction' == 'none' |
8486
| Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice ||| |
8587
| Sigmoid | ai.onnx(7-12, 13+) | sigmoid ||| |

js/web/test/suite-test-list.jsonc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2254,10 +2254,10 @@
22542254
// // "test_round",
22552255
// // "test_scan_sum",
22562256
// // "test_scan9_sum",
2257-
// // "test_scatter_elements_with_axis",
2258-
// // "test_scatter_elements_with_duplicate_indices",
2259-
// // "test_scatter_elements_with_negative_indices",
2260-
// // "test_scatter_elements_without_axis",
2257+
"test_scatter_elements_with_axis",
2258+
"test_scatter_elements_with_duplicate_indices",
2259+
"test_scatter_elements_with_negative_indices",
2260+
"test_scatter_elements_without_axis",
22612261
// // "test_scatter_with_axis",
22622262
// // "test_scatter_without_axis",
22632263
"test_scatternd_add",

onnxruntime/core/providers/webnn/builders/helper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
179179
{"Flatten", "reshape"},
180180
{"Floor", "floor"},
181181
{"Gather", "gather"},
182+
{"GatherElements", "gatherElements"},
182183
{"GatherND", "gatherND"},
183184
{"Gelu", "gelu"},
184185
{"Gemm", "gemm"},
@@ -225,6 +226,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
225226
{"Relu", "relu"},
226227
{"Reshape", "reshape"},
227228
{"Resize", "resample2d"},
229+
{"ScatterElements", "scatterElements"},
228230
{"ScatterND", "scatterND"},
229231
{"Shape", "slice"},
230232
{"Sigmoid", "sigmoid"},
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Copyright (c) Intel Corporation. All rights reserved.
3+
// Licensed under the MIT License.
4+
5+
#include "core/providers/common.h"
6+
#include "core/providers/shared/utils/utils.h"
7+
#include "core/providers/webnn/builders/helper.h"
8+
#include "core/providers/webnn/builders/model_builder.h"
9+
#include "core/providers/webnn/builders/op_builder_factory.h"
10+
11+
#include "base_op_builder.h"
12+
13+
namespace onnxruntime {
14+
namespace webnn {
15+
16+
class GatherElementsOpBuilder : public BaseOpBuilder {
17+
// Add operator related.
18+
private:
19+
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
20+
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
21+
22+
// Operator support related.
23+
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
24+
const logging::Logger& logger) const override;
25+
};
26+
27+
// Add operator related.
28+
29+
Status GatherElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
30+
const logging::Logger& logger) const {
31+
const auto& input_defs = node.InputDefs();
32+
emscripten::val data = model_builder.GetOperand(input_defs[0]->Name());
33+
emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name());
34+
emscripten::val options = emscripten::val::object();
35+
options.set("label", node.Name());
36+
37+
std::vector<int64_t> input_shape;
38+
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
39+
const size_t rank = input_shape.size();
40+
NodeAttrHelper helper(node);
41+
const uint32_t axis = static_cast<uint32_t>(HandleNegativeAxis(helper.Get("axis", 0), rank));
42+
options.set("axis", axis);
43+
44+
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("gatherElements", data, indices, options);
45+
46+
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
47+
return Status::OK();
48+
}
49+
50+
// Operator support related.
51+
52+
bool GatherElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
53+
const logging::Logger& logger) const {
54+
const auto& data = *node.InputDefs()[0];
55+
const auto& indices = *node.InputDefs()[1];
56+
const auto& op_type = node.OpType();
57+
58+
int32_t data_type;
59+
int32_t indices_type;
60+
if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) {
61+
return false;
62+
}
63+
64+
return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
65+
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
66+
}
67+
68+
void CreateGatherElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
69+
op_registrations.builders.push_back(std::make_unique<GatherElementsOpBuilder>());
70+
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
71+
}
72+
73+
} // namespace webnn
74+
} // namespace onnxruntime
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Copyright (c) Intel Corporation. All rights reserved.
3+
// Licensed under the MIT License.
4+
5+
#include "core/providers/common.h"
6+
#include "core/providers/shared/utils/utils.h"
7+
#include "core/providers/webnn/builders/helper.h"
8+
#include "core/providers/webnn/builders/model_builder.h"
9+
#include "core/providers/webnn/builders/op_builder_factory.h"
10+
11+
#include "base_op_builder.h"
12+
13+
namespace onnxruntime {
14+
namespace webnn {
15+
16+
class ScatterElementsOpBuilder : public BaseOpBuilder {
17+
// Add operator related.
18+
private:
19+
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
20+
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
21+
22+
// Operator support related.
23+
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
24+
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
25+
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
26+
const logging::Logger& logger) const override;
27+
};
28+
29+
// Add operator related.
30+
31+
Status ScatterElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
32+
const logging::Logger& logger) const {
33+
const auto& input_defs = node.InputDefs();
34+
emscripten::val data = model_builder.GetOperand(input_defs[0]->Name());
35+
emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name());
36+
emscripten::val updates = model_builder.GetOperand(input_defs[2]->Name());
37+
emscripten::val options = emscripten::val::object();
38+
options.set("label", node.Name());
39+
40+
std::vector<int64_t> input_shape;
41+
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
42+
const size_t rank = input_shape.size();
43+
NodeAttrHelper helper(node);
44+
const uint32_t axis = static_cast<uint32_t>(HandleNegativeAxis(helper.Get("axis", 0), rank));
45+
options.set("axis", axis);
46+
47+
emscripten::val output =
48+
model_builder.GetBuilder().call<emscripten::val>("scatterElements", data, indices, updates, options);
49+
50+
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
51+
return Status::OK();
52+
}
53+
54+
// Operator support related.
55+
56+
bool ScatterElementsOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
57+
const WebnnDeviceType /* device_type */,
58+
const logging::Logger& logger) const {
59+
NodeAttrHelper helper(node);
60+
if (helper.Get("reduction", "none") != "none") {
61+
LOGS(logger, VERBOSE) << "ScatterElements: the type of reduction must be none (default)";
62+
return false;
63+
}
64+
65+
return true;
66+
}
67+
68+
bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
69+
const logging::Logger& logger) const {
70+
const auto& data = *node.InputDefs()[0];
71+
const auto& indices = *node.InputDefs()[1];
72+
const auto& updates = *node.InputDefs()[2];
73+
const auto& op_type = node.OpType();
74+
75+
int32_t data_type;
76+
int32_t indices_type;
77+
int32_t updates_type;
78+
if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger) ||
79+
!GetType(updates, updates_type, logger)) {
80+
return false;
81+
}
82+
83+
if (data_type != updates_type) {
84+
return false;
85+
}
86+
87+
return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
88+
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
89+
}
90+
91+
void CreateScatterElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
92+
op_registrations.builders.push_back(std::make_unique<ScatterElementsOpBuilder>());
93+
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
94+
}
95+
96+
} // namespace webnn
97+
} // namespace onnxruntime

onnxruntime/core/providers/webnn/builders/op_builder_factory.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
9898
CreateGatherOpBuilder("Gather", op_registrations);
9999
}
100100

101+
{ // GatherElements
102+
CreateGatherElementsOpBuilder("GatherElements", op_registrations);
103+
}
104+
101105
{ // GatherND
102106
CreateGatherNDOpBuilder("GatherND", op_registrations);
103107
}
@@ -174,6 +178,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
174178
CreateResizeOpBuilder("Resize", op_registrations);
175179
}
176180

181+
{ // ScatterElements
182+
CreateScatterElementsOpBuilder("ScatterElements", op_registrations);
183+
}
184+
177185
{ // ScatterND
178186
CreateScatterNDOpBuilder("ScatterND", op_registrations);
179187
}

onnxruntime/core/providers/webnn/builders/op_builder_factory.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderR
3131
void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3232
void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3333
void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
34+
void CreateGatherElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3435
void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3536
void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3637
void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
@@ -44,6 +45,7 @@ void CreateQDQOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r
4445
void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
4546
void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
4647
void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
48+
void CreateScatterElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
4749
void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
4850
void CreateShapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
4951
void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

0 commit comments

Comments
 (0)