Skip to content

Commit 2c22283

Browse files
committed
fix rebase compilation errors
* pass WebNNGraph receiver to ContextImplOrt::CreateGraphImpl() * remove instanceNorm layout check * add BatchNormalizationAxis in context properties
1 parent 53f2842 commit 2c22283

File tree

6 files changed

+23
-23
lines changed

6 files changed

+23
-23
lines changed

services/webnn/ort/context_impl_ort.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ ContextProperties ContextImplOrt::GetContextProperties(
267267

268268
return ContextProperties(
269269
InputOperandLayout::kNchw, Resample2DAxes::kChannelsFirst,
270+
BatchNormalizationAxis::kChannelsFirst,
270271
/*tensor_byte_length_limit=*/kTensorByteLengthLimit,
271272
{/*input=*/SupportedDataTypes::All(),
272273
/*constant=*/SupportedDataTypes::All(),
@@ -476,14 +477,16 @@ base::WeakPtr<WebNNContextImpl> ContextImplOrt::AsWeakPtr() {
476477
}
477478

478479
void ContextImplOrt::CreateGraphImpl(
480+
mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
479481
mojom::GraphInfoPtr graph_info,
480482
WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
481483
base::flat_map<uint64_t, std::unique_ptr<WebNNConstantOperand>>
482484
constant_operands,
483485
CreateGraphImplCallback callback) {
484-
GraphImplOrt::CreateAndBuild(
485-
std::move(graph_info), std::move(compute_resource_info),
486-
std::move(constant_operands), this, std::move(callback));
486+
GraphImplOrt::CreateAndBuild(std::move(receiver), std::move(graph_info),
487+
std::move(compute_resource_info),
488+
std::move(constant_operands), this,
489+
std::move(callback));
487490
}
488491

489492
void ContextImplOrt::CreateTensorImpl(

services/webnn/ort/context_impl_ort.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class ContextImplOrt final : public WebNNContextImpl {
6161

6262
private:
6363
void CreateGraphImpl(
64+
mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
6465
mojom::GraphInfoPtr graph_info,
6566
WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
6667
base::flat_map<uint64_t, std::unique_ptr<WebNNConstantOperand>>

services/webnn/ort/graph_builder_ort.cc

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -823,13 +823,6 @@ GraphBuilderOrt::AddBatchNormalizationOperation(
823823

824824
const std::vector<uint32_t>& input_shape =
825825
GetOperand(batch_normalization.input_operand_id).descriptor.shape();
826-
// TODO: Support NHWC layout-
827-
// https://github.com/shiyi9801/chromium/issues/77
828-
if (batch_normalization.axis != 1) {
829-
return NewNotSupportedError(
830-
"Unsupported axis since BatchNormalization only supports NCHW layout "
831-
"currently. ");
832-
}
833826
uint32_t input_channel = input_shape[1];
834827
std::vector<uint32_t> constant_dims = {input_channel};
835828

@@ -2095,12 +2088,6 @@ GraphBuilderOrt::AddInstanceNormalizationOperation(
20952088

20962089
const std::vector<uint32_t>& input_shape =
20972090
GetOperand(instance_normalization.input_operand_id).descriptor.shape();
2098-
// TODO(crbug.com/387312212): Support NHWC layout
2099-
if (instance_normalization.layout ==
2100-
mojom::InputOperandLayout::kChannelsLast) {
2101-
return NewNotSupportedError(
2102-
"[WebNN] Currently InstanceNormalization only supports NCHW layout.");
2103-
}
21042091
CHECK_EQ(context_properties_.input_operand_layout, InputOperandLayout::kNchw);
21052092
uint32_t input_channel = input_shape[1];
21062093
std::vector<uint32_t> constant_dims = {input_channel};

services/webnn/ort/graph_impl_ort.cc

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ class GraphImplOrt::ComputeResources {
132132

133133
// static
134134
void GraphImplOrt::CreateAndBuild(
135+
mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
135136
mojom::GraphInfoPtr graph_info,
136137
ComputeResourceInfo compute_resource_info,
137138
base::flat_map<uint64_t, std::unique_ptr<WebNNConstantOperand>>
@@ -141,8 +142,9 @@ void GraphImplOrt::CreateAndBuild(
141142
ScopedTrace scoped_trace("GraphImplOrt::CreateAndBuild");
142143

143144
auto wrapped_callback = base::BindPostTaskToCurrentDefault(
144-
base::BindOnce(&GraphImplOrt::DidCreateAndBuild, context->AsWeakPtr(),
145-
std::move(compute_resource_info), std::move(callback)));
145+
base::BindOnce(&GraphImplOrt::DidCreateAndBuild, std::move(receiver),
146+
context->AsWeakPtr(), std::move(compute_resource_info),
147+
std::move(callback)));
146148

147149
base::ThreadPool::PostTaskAndReplyWithResult(
148150
FROM_HERE,
@@ -224,6 +226,7 @@ GraphImplOrt::CreateAndBuildOnBackgroundThread(
224226

225227
// static
226228
void GraphImplOrt::DidCreateAndBuild(
229+
mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
227230
base::WeakPtr<WebNNContextImpl> context,
228231
ComputeResourceInfo compute_resource_info,
229232
WebNNContextImpl::CreateGraphImplCallback callback,
@@ -241,17 +244,20 @@ void GraphImplOrt::DidCreateAndBuild(
241244
}
242245

243246
std::move(callback).Run(base::WrapUnique(new GraphImplOrt(
244-
std::move(compute_resource_info), std::move(result.value()),
245-
static_cast<ContextImplOrt*>(context.get()))));
247+
std::move(receiver), std::move(compute_resource_info),
248+
std::move(result.value()), static_cast<ContextImplOrt*>(context.get()))));
246249
}
247250

248251
GraphImplOrt::~GraphImplOrt() = default;
249252

250253
GraphImplOrt::GraphImplOrt(
254+
mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
251255
ComputeResourceInfo compute_resource_info,
252256
std::unique_ptr<GraphImplOrt::ComputeResources> compute_resources,
253257
ContextImplOrt* context)
254-
: WebNNGraphImpl(context, std::move(compute_resource_info)) {
258+
: WebNNGraphImpl(std::move(receiver),
259+
context,
260+
std::move(compute_resource_info)) {
255261
compute_resources_state_ =
256262
base::MakeRefCounted<QueueableResourceState<ComputeResources>>(
257263
std::move(compute_resources));

services/webnn/ort/graph_impl_ort.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class SessionOptions;
3939
class GraphImplOrt final : public WebNNGraphImpl {
4040
public:
4141
static void CreateAndBuild(
42+
mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
4243
mojom::GraphInfoPtr graph_info,
4344
ComputeResourceInfo compute_resource_info,
4445
base::flat_map<uint64_t, std::unique_ptr<WebNNConstantOperand>>
@@ -53,7 +54,8 @@ class GraphImplOrt final : public WebNNGraphImpl {
5354
private:
5455
class ComputeResources;
5556

56-
GraphImplOrt(ComputeResourceInfo compute_resource_info,
57+
GraphImplOrt(mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
58+
ComputeResourceInfo compute_resource_info,
5759
std::unique_ptr<ComputeResources> compute_resources,
5860
ContextImplOrt* context);
5961

@@ -67,6 +69,7 @@ class GraphImplOrt final : public WebNNGraphImpl {
6769
ScopedTrace scoped_trace);
6870

6971
static void DidCreateAndBuild(
72+
mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
7073
base::WeakPtr<WebNNContextImpl> context,
7174
ComputeResourceInfo compute_resource_info,
7275
WebNNContextImpl::CreateGraphImplCallback callback,

services/webnn/webnn_context_provider_impl.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ void WebNNContextProviderImpl::CreateWebNNContext(
255255

256256
OrtLoggingLevel ort_logging_level = ORT_LOGGING_LEVEL_WARNING;
257257
if (base::CommandLine::ForCurrentProcess()->HasSwitch(
258-
switches::kWebNNOrtLoggingLevel)) {
258+
switches::kWebNNOrtLoggingLevel)) {
259259
std::string user_logging_level =
260260
base::CommandLine::ForCurrentProcess()->GetSwitchValueASCII(
261261
switches::kWebNNOrtLoggingLevel);

0 commit comments

Comments
 (0)