From e0765f79d7242ad25e89d3aae30ada3deadb5642 Mon Sep 17 00:00:00 2001 From: dtang317 Date: Wed, 30 Oct 2024 14:32:09 -0700 Subject: [PATCH 1/2] add an extra step (clip) to cast op when casting to bool type --- .../src/Operators/DmlOperatorCast.cpp | 55 ++++++++++++++++++- .../test/providers/cpu/tensor/cast_op_test.cc | 4 -- 2 files changed, 53 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp index 02fb72b5a073a..0ae51a7eb8809 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp @@ -28,9 +28,60 @@ class DmlOperatorCast : public DmlOperator castDesc.InputTensor = inputDescs.data(); castDesc.OutputTensor = outputDescs.data(); - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CAST, &castDesc }; + if (kernelInfo.GetOutputEdgeDescription(0).tensorDataType == static_cast(ONNX_NAMESPACE::TensorProto_DataType_BOOL)) + { + DML_OPERATOR_DESC dmlCastDesc = { DML_OPERATOR_CAST, &castDesc }; - SetDmlOperatorDesc(opDesc, kernelInfo); + DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC clipDesc = {}; + clipDesc.InputTensor = outputDescs.data(); + clipDesc.OutputTensor = outputDescs.data(); + DML_SCALAR_UNION minValue; + minValue.Int32 = 0; + DML_SCALAR_UNION maxValue; + maxValue.Int32 = 1; + clipDesc.Min = minValue; + clipDesc.Max = maxValue; + + DML_OPERATOR_DESC dmlClipDesc = { DML_OPERATOR_ELEMENT_WISE_CLIP1, &clipDesc }; + + std::vector opDescs = { &dmlCastDesc, &dmlClipDesc }; + + DML_INPUT_GRAPH_EDGE_DESC inputToCastEdge = {}; + inputToCastEdge.GraphInputIndex = 0; + inputToCastEdge.ToNodeIndex = 0; + inputToCastEdge.ToNodeInputIndex = 0; + + DML_INTERMEDIATE_GRAPH_EDGE_DESC castToClipEdge = {}; + castToClipEdge.FromNodeIndex = 0; + castToClipEdge.FromNodeOutputIndex = 0; + castToClipEdge.ToNodeIndex = 1; + castToClipEdge.ToNodeInputIndex = 0; + + DML_OUTPUT_GRAPH_EDGE_DESC clipToOutputEdge = {}; + clipToOutputEdge.FromNodeIndex = 1; + clipToOutputEdge.FromNodeOutputIndex = 0; + clipToOutputEdge.GraphOutputIndex = 0; + + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); + operatorGraphDesc.nodes = opDescs.data(); + + operatorGraphDesc.inputEdgeCount = 1; + operatorGraphDesc.inputEdges = &inputToCastEdge; + + operatorGraphDesc.intermediateEdgeCount = 1; + operatorGraphDesc.intermediateEdges = &castToClipEdge; + + operatorGraphDesc.outputEdgeCount = 1; + operatorGraphDesc.outputEdges = &clipToOutputEdge; + + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo); + } + else + { + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CAST, &castDesc }; + SetDmlOperatorDesc(opDesc, kernelInfo); + } } void Compute(const MLOperatorKernelContext& kernelContext) diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 421561a5a859b..97d7474f9cb55 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -149,10 +149,6 @@ using CastNonStringTypes = MLFloat16, BFloat16>; TEST(CastOpTest, NonStringTypes) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: Expected equality of these values: true and true"; - } boost::mp11::mp_for_each>( CastNonStringTester{}); From ee79c9a18838fe46cf627527e4524ec465e8a3ee Mon Sep 17 00:00:00 2001 From: dtang317 Date: Thu, 31 Oct 2024 09:49:41 -0700 Subject: [PATCH 2/2] update clipDesc, clean up format --- .../src/Operators/DmlOperatorCast.cpp | 8 ++------ onnxruntime/test/providers/cpu/tensor/cast_op_test.cc | 1 - 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp index 0ae51a7eb8809..389bdee3a365b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp @@ -35,12 +35,8 @@ class DmlOperatorCast : public DmlOperator DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC clipDesc = {}; clipDesc.InputTensor = outputDescs.data(); clipDesc.OutputTensor = outputDescs.data(); - DML_SCALAR_UNION minValue; - minValue.Int32 = 0; - DML_SCALAR_UNION maxValue; - maxValue.Int32 = 1; - clipDesc.Min = minValue; - clipDesc.Max = maxValue; + clipDesc.Min.UInt8 = 0; + clipDesc.Max.UInt8 = 1; DML_OPERATOR_DESC dmlClipDesc = { DML_OPERATOR_ELEMENT_WISE_CLIP1, &clipDesc }; diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 97d7474f9cb55..384adb5916cc1 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -149,7 +149,6 @@ using CastNonStringTypes = MLFloat16, BFloat16>; TEST(CastOpTest, NonStringTypes) { - boost::mp11::mp_for_each>( CastNonStringTester{}); }