Skip to content

Commit 8cf28ba

Browse files
miaobinankitm3k
authored andcommitted
[WebNN EP] Support Sign and CumSum operators (microsoft#22616)
This PR supports Sign and CumSum operators for WebNN EP. @Honry @fdwr PTAL, thanks.
1 parent 862a94c commit 8cf28ba

File tree

4 files changed

+9
-30
lines changed

4 files changed

+9
-30
lines changed

js/web/docs/webnn-operators.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
2525
| Conv | ai.onnx(7-10, 11+) | conv2d ||| Only supports 3-D or 4-D input and 'W' (weight) |
2626
| ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d ||| Only supports 3-D or 4-D input and 'W' (weight). WebNN CPU backend only supports default dilations and group |
2727
| Cos | ai.onnx(7+) | cos ||| |
28-
| CumSum | ai.onnx(11-13, 14+) | cumulativeSum ||| 'axis' input should be a constant |
28+
| CumSum | ai.onnx(11-13, 14+) | cumulativeSum ||| |
2929
| Div | ai.onnx(7-12, 13, 14+) | div ||| |
3030
| DequantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | dequantizeLinear ||| |
3131
| Dropout | ai.onnx(7-9, 10-11, 12, 13-21, 22+) | identity ||| Only supports test mode |

onnxruntime/core/providers/webnn/builders/helper.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,6 @@ static const InlinedHashMap<std::string, std::string> op_map = {
273273
{"Shape", "slice"},
274274
{"Sigmoid", "sigmoid"},
275275
{"Sign", "sign"},
276-
{"SimplifiedLayerNormalization", "layerNormalization"},
277276
{"Softplus", "softplus"},
278277
{"Softsign", "softsign"},
279278
{"Sin", "sin"},

onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@ namespace webnn {
1919
class CumSumOpBuilder : public BaseOpBuilder {
2020
// Add operator related.
2121

22-
public:
23-
void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
24-
2522
private:
2623
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
2724
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
@@ -33,28 +30,19 @@ class CumSumOpBuilder : public BaseOpBuilder {
3330
};
3431

3532
// Add operator related.
36-
37-
void CumSumOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
38-
// Skip axis.
39-
model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());
40-
}
41-
42-
Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
33+
Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
34+
const Node& node,
4335
const logging::Logger& logger) const {
4436
const auto& input_defs = node.InputDefs();
4537
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
4638
std::vector<int64_t> input_shape;
4739
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape");
4840
const auto input_rank = input_shape.size();
4941

50-
const auto& initializers = model_builder.GetInitializerTensors();
51-
const std::string axis_name = GetTensorName(input_defs, 1);
52-
const auto axis_tensor = *initializers.at(axis_name);
53-
emscripten::val axis = emscripten::val::undefined();
54-
ORT_RETURN_IF_NOT(ReadScalarTensorData(axis_tensor, axis, logger), "Cannot get axis value");
55-
int64_t webnn_axis = HandleNegativeAxis(axis.as<int64_t>(), input_rank);
56-
5742
NodeAttrHelper helper(node);
43+
int64_t axis = helper.Get("axis", 0);
44+
axis = HandleNegativeAxis(axis, input_rank);
45+
5846
const auto exclusive = helper.Get("exclusive", 0);
5947
const auto reverse = helper.Get("reverse", 0);
6048

@@ -64,14 +52,13 @@ Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
6452
options.set("label", node.Name());
6553

6654
emscripten::val output = emscripten::val::object();
67-
output = model_builder.GetBuilder().call<emscripten::val>("cumulativeSum", input, gsl::narrow<uint32_t>(webnn_axis),
68-
options);
55+
output = model_builder.GetBuilder().call<emscripten::val>("cumulativeSum", input, gsl::narrow<uint32_t>(axis), options);
6956
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
7057
return Status::OK();
7158
}
7259

7360
// Operator support related.
74-
bool CumSumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
61+
bool CumSumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
7562
const Node& node,
7663
WebnnDeviceType /* device_type */,
7764
const logging::Logger& logger) const {
@@ -81,13 +68,6 @@ bool CumSumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
8168
if (!GetShape(*input_defs[0], input_shape, logger))
8269
return false;
8370

84-
const std::string axis_name = GetTensorName(input_defs, 1);
85-
// Inputs contain optional 'axis' input.
86-
if (!Contains(initializers, axis_name)) {
87-
LOGS(logger, VERBOSE) << "The axis must be a constant initializer.";
88-
return false;
89-
}
90-
9171
return true;
9272
}
9373

onnxruntime/core/providers/webnn/builders/op_builder_factory.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
8282
}
8383

8484
{ // CumSum
85-
CreateCumSumOpBuilder("CumSum", op_registrations);
85+
CreateConcatOpBuilder("CumSum", op_registrations);
8686
}
8787

8888
{ // Dropout

0 commit comments

Comments
 (0)