Skip to content

Commit 88b63d0

Browse files
authored
[WebNN] Rename outputs for gru and lstm (#26196)
WebNN spec changed the output names of `gru` and `lstm` for `OpSupportLimits` in webmachinelearning/webnn#857, renamed them in WebNN EP as well.
1 parent 189e673 commit 88b63d0

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,15 +237,15 @@ bool GruOpBuilder::HasSupportedOutputsImpl(const Node& node,
237237
bool Y_h_supported = has_Y_h && GetType(*output_defs[1], Y_h_type, logger);
238238

239239
if (Y_supported && !Y_h_supported) {
240-
return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "outputs", "Y", logger);
240+
return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "output1", "Y", logger);
241241
} else if (!Y_supported && Y_h_supported) {
242-
return IsDataTypeSupportedByOp(op_type, Y_h_type, wnn_limits, "outputs", "Y_h", logger);
242+
return IsDataTypeSupportedByOp(op_type, Y_h_type, wnn_limits, "output0", "Y_h", logger);
243243
} else if (Y_supported && Y_h_supported) {
244244
if (Y_type != Y_h_type) {
245245
LOGS(logger, VERBOSE) << "[GRU] Output data types must be the same.";
246246
return false;
247247
}
248-
return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "outputs", "Y", logger);
248+
return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "output1", "Y", logger);
249249
} else {
250250
LOGS(logger, VERBOSE) << "[GRU] No output found.";
251251
return false;

onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,13 +259,13 @@ bool LstmOpBuilder::HasSupportedOutputsImpl(const Node& node,
259259
bool has_Y_c = TensorExists(output_defs, 2);
260260

261261
if (has_Y && GetType(*output_defs[0], Y_type, logger)) {
262-
return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "outputs", "Y", logger);
262+
return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "output2", "Y", logger);
263263
}
264264
if (has_Y_h && GetType(*output_defs[1], Y_h_type, logger)) {
265-
return IsDataTypeSupportedByOp(op_type, Y_h_type, wnn_limits, "outputs", "Y_h", logger);
265+
return IsDataTypeSupportedByOp(op_type, Y_h_type, wnn_limits, "output0", "Y_h", logger);
266266
}
267267
if (has_Y_c && GetType(*output_defs[2], Y_c_type, logger)) {
268-
return IsDataTypeSupportedByOp(op_type, Y_c_type, wnn_limits, "outputs", "Y_c", logger);
268+
return IsDataTypeSupportedByOp(op_type, Y_c_type, wnn_limits, "output1", "Y_c", logger);
269269
}
270270

271271
return false;

0 commit comments

Comments
 (0)