-
Notifications
You must be signed in to change notification settings - Fork 3.5k
[WebNN EP] Support Sign and CumSum operators #22616
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Copyright (c) Intel Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "core/common/safeint.h" | ||
#include "core/framework/tensorprotoutils.h" | ||
#include "core/optimizer/initializer.h" | ||
#include "core/providers/common.h" | ||
#include "core/providers/shared/utils/utils.h" | ||
#include "core/providers/webnn/builders/helper.h" | ||
#include "core/providers/webnn/builders/model_builder.h" | ||
#include "core/providers/webnn/builders/op_builder_factory.h" | ||
|
||
#include "base_op_builder.h" | ||
Check warning on line 14 in onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc
|
||
|
||
namespace onnxruntime { | ||
namespace webnn { | ||
|
||
class CumSumOpBuilder : public BaseOpBuilder { | ||
// Add operator related. | ||
|
||
private: | ||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, | ||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT; | ||
|
||
// Operator support related. | ||
private: | ||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, | ||
const logging::Logger& logger) const override; | ||
}; | ||
|
||
// Add operator related. | ||
Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, | ||
const Node& node, | ||
Check warning on line 34 in onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc
|
||
const logging::Logger& logger) const { | ||
Check warning on line 35 in onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc
|
||
const auto& input_defs = node.InputDefs(); | ||
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); | ||
std::vector<int64_t> input_shape; | ||
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); | ||
const auto input_rank = input_shape.size(); | ||
|
||
NodeAttrHelper helper(node); | ||
int64_t axis = helper.Get("axis", 0); | ||
axis = HandleNegativeAxis(axis, input_rank); | ||
|
||
const auto exclusive = helper.Get("exclusive", 0); | ||
const auto reverse = helper.Get("reverse", 0); | ||
|
||
emscripten::val options = emscripten::val::object(); | ||
options.set("exclusive", exclusive == 1); | ||
options.set("reversed", reverse == 1); | ||
options.set("label", node.Name()); | ||
|
||
emscripten::val output = emscripten::val::object(); | ||
output = model_builder.GetBuilder().call<emscripten::val>("cumulativeSum", input, narrow<uint32_t>(axis), options); | ||
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); | ||
Check warning on line 56 in onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc
|
||
return Status::OK(); | ||
} | ||
|
||
// Operator support related. | ||
bool CumSumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, | ||
const Node& node, | ||
Check warning on line 62 in onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc
|
||
const logging::Logger& logger) const { | ||
Check warning on line 63 in onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc
|
||
const auto& input_defs = node.InputDefs(); | ||
|
||
std::vector<int64_t> input_shape; | ||
Check warning on line 66 in onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc
|
||
if (!GetShape(*input_defs[0], input_shape, logger)) | ||
return false; | ||
|
||
return true; | ||
} | ||
|
||
void CreateCumSumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { | ||
Check warning on line 73 in onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc
|
||
op_registrations.builders.push_back(std::make_unique<CumSumOpBuilder>()); | ||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); | ||
} | ||
|
||
} // namespace webnn | ||
} // namespace onnxruntime |
Uh oh!
There was an error while loading. Please reload this page.