diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp index 02fb72b5a073a..389bdee3a365b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp @@ -28,9 +28,56 @@ class DmlOperatorCast : public DmlOperator castDesc.InputTensor = inputDescs.data(); castDesc.OutputTensor = outputDescs.data(); - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CAST, &castDesc }; + if (kernelInfo.GetOutputEdgeDescription(0).tensorDataType == static_cast(ONNX_NAMESPACE::TensorProto_DataType_BOOL)) + { + DML_OPERATOR_DESC dmlCastDesc = { DML_OPERATOR_CAST, &castDesc }; - SetDmlOperatorDesc(opDesc, kernelInfo); + DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC clipDesc = {}; + clipDesc.InputTensor = outputDescs.data(); + clipDesc.OutputTensor = outputDescs.data(); + clipDesc.Min.UInt8 = 0; + clipDesc.Max.UInt8 = 1; + + DML_OPERATOR_DESC dmlClipDesc = { DML_OPERATOR_ELEMENT_WISE_CLIP1, &clipDesc }; + + std::vector opDescs = { &dmlCastDesc, &dmlClipDesc }; + + DML_INPUT_GRAPH_EDGE_DESC inputToCastEdge = {}; + inputToCastEdge.GraphInputIndex = 0; + inputToCastEdge.ToNodeIndex = 0; + inputToCastEdge.ToNodeInputIndex = 0; + + DML_INTERMEDIATE_GRAPH_EDGE_DESC castToClipEdge = {}; + castToClipEdge.FromNodeIndex = 0; + castToClipEdge.FromNodeOutputIndex = 0; + castToClipEdge.ToNodeIndex = 1; + castToClipEdge.ToNodeInputIndex = 0; + + DML_OUTPUT_GRAPH_EDGE_DESC clipToOutputEdge = {}; + clipToOutputEdge.FromNodeIndex = 1; + clipToOutputEdge.FromNodeOutputIndex = 0; + clipToOutputEdge.GraphOutputIndex = 0; + + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); + operatorGraphDesc.nodes = opDescs.data(); + + operatorGraphDesc.inputEdgeCount = 1; + operatorGraphDesc.inputEdges = &inputToCastEdge; + + operatorGraphDesc.intermediateEdgeCount = 1; + operatorGraphDesc.intermediateEdges = &castToClipEdge; + + operatorGraphDesc.outputEdgeCount = 1; + operatorGraphDesc.outputEdges = &clipToOutputEdge; + + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo); + } + else + { + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CAST, &castDesc }; + SetDmlOperatorDesc(opDesc, kernelInfo); + } } void Compute(const MLOperatorKernelContext& kernelContext) diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 421561a5a859b..384adb5916cc1 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -149,11 +149,6 @@ using CastNonStringTypes = MLFloat16, BFloat16>; TEST(CastOpTest, NonStringTypes) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: Expected equality of these values: true and true"; - } - boost::mp11::mp_for_each>( CastNonStringTester{}); }