Skip to content

Commit 55e0128

Browse files
authored
[DML EP] Cast to bool correctly, adding explicit clip after cast (#22665)
### Description The CastNonStringTester test in CastOpTest was failing due to bitwise mismatches when casting other types to bool. This was caused by bool being represented as uint8 in DML. Added a clipping step in DmlOperatorCast to ensure correct bitwise matching after casting to bool ref: https://dev.azure.com/microsoft/OS/_workitems/edit/44572678 ### Motivation and Context
1 parent 1b60209 commit 55e0128

File tree

2 files changed

+49
-7
lines changed

2 files changed

+49
-7
lines changed

onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,56 @@ class DmlOperatorCast : public DmlOperator
2828
castDesc.InputTensor = inputDescs.data();
2929
castDesc.OutputTensor = outputDescs.data();
3030

31-
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CAST, &castDesc };
31+
if (kernelInfo.GetOutputEdgeDescription(0).tensorDataType == static_cast<MLOperatorTensorDataType>(ONNX_NAMESPACE::TensorProto_DataType_BOOL))
32+
{
33+
DML_OPERATOR_DESC dmlCastDesc = { DML_OPERATOR_CAST, &castDesc };
3234

33-
SetDmlOperatorDesc(opDesc, kernelInfo);
35+
DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC clipDesc = {};
36+
clipDesc.InputTensor = outputDescs.data();
37+
clipDesc.OutputTensor = outputDescs.data();
38+
clipDesc.Min.UInt8 = 0;
39+
clipDesc.Max.UInt8 = 1;
40+
41+
DML_OPERATOR_DESC dmlClipDesc = { DML_OPERATOR_ELEMENT_WISE_CLIP1, &clipDesc };
42+
43+
std::vector<const DML_OPERATOR_DESC*> opDescs = { &dmlCastDesc, &dmlClipDesc };
44+
45+
DML_INPUT_GRAPH_EDGE_DESC inputToCastEdge = {};
46+
inputToCastEdge.GraphInputIndex = 0;
47+
inputToCastEdge.ToNodeIndex = 0;
48+
inputToCastEdge.ToNodeInputIndex = 0;
49+
50+
DML_INTERMEDIATE_GRAPH_EDGE_DESC castToClipEdge = {};
51+
castToClipEdge.FromNodeIndex = 0;
52+
castToClipEdge.FromNodeOutputIndex = 0;
53+
castToClipEdge.ToNodeIndex = 1;
54+
castToClipEdge.ToNodeInputIndex = 0;
55+
56+
DML_OUTPUT_GRAPH_EDGE_DESC clipToOutputEdge = {};
57+
clipToOutputEdge.FromNodeIndex = 1;
58+
clipToOutputEdge.FromNodeOutputIndex = 0;
59+
clipToOutputEdge.GraphOutputIndex = 0;
60+
61+
MLOperatorGraphDesc operatorGraphDesc = {};
62+
operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
63+
operatorGraphDesc.nodes = opDescs.data();
64+
65+
operatorGraphDesc.inputEdgeCount = 1;
66+
operatorGraphDesc.inputEdges = &inputToCastEdge;
67+
68+
operatorGraphDesc.intermediateEdgeCount = 1;
69+
operatorGraphDesc.intermediateEdges = &castToClipEdge;
70+
71+
operatorGraphDesc.outputEdgeCount = 1;
72+
operatorGraphDesc.outputEdges = &clipToOutputEdge;
73+
74+
SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo);
75+
}
76+
else
77+
{
78+
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CAST, &castDesc };
79+
SetDmlOperatorDesc(opDesc, kernelInfo);
80+
}
3481
}
3582

3683
void Compute(const MLOperatorKernelContext& kernelContext)

onnxruntime/test/providers/cpu/tensor/cast_op_test.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,6 @@ using CastNonStringTypes =
149149
MLFloat16, BFloat16>;
150150

151151
TEST(CastOpTest, NonStringTypes) {
152-
// TODO: Unskip when fixed #41968513
153-
if (DefaultDmlExecutionProvider().get() != nullptr) {
154-
GTEST_SKIP() << "Skipping because of the following error: Expected equality of these values: true and true";
155-
}
156-
157152
boost::mp11::mp_for_each<boost::mp11::mp_product<std::pair, CastNonStringTypes, CastNonStringTypes>>(
158153
CastNonStringTester{});
159154
}

0 commit comments

Comments
 (0)