Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,56 @@
castDesc.InputTensor = inputDescs.data();
castDesc.OutputTensor = outputDescs.data();

DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CAST, &castDesc };
if (kernelInfo.GetOutputEdgeDescription(0).tensorDataType == static_cast<MLOperatorTensorDataType>(ONNX_NAMESPACE::TensorProto_DataType_BOOL))
{
DML_OPERATOR_DESC dmlCastDesc = { DML_OPERATOR_CAST, &castDesc };

SetDmlOperatorDesc(opDesc, kernelInfo);
DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC clipDesc = {};
clipDesc.InputTensor = outputDescs.data();
clipDesc.OutputTensor = outputDescs.data();
clipDesc.Min.UInt8 = 0;
clipDesc.Max.UInt8 = 1;

DML_OPERATOR_DESC dmlClipDesc = { DML_OPERATOR_ELEMENT_WISE_CLIP1, &clipDesc };

std::vector<const DML_OPERATOR_DESC*> opDescs = { &dmlCastDesc, &dmlClipDesc };

DML_INPUT_GRAPH_EDGE_DESC inputToCastEdge = {};
inputToCastEdge.GraphInputIndex = 0;
inputToCastEdge.ToNodeIndex = 0;
inputToCastEdge.ToNodeInputIndex = 0;

DML_INTERMEDIATE_GRAPH_EDGE_DESC castToClipEdge = {};
castToClipEdge.FromNodeIndex = 0;

Check warning on line 51 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp:51: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]
castToClipEdge.FromNodeOutputIndex = 0;
castToClipEdge.ToNodeIndex = 1;
castToClipEdge.ToNodeInputIndex = 0;

DML_OUTPUT_GRAPH_EDGE_DESC clipToOutputEdge = {};
clipToOutputEdge.FromNodeIndex = 1;
clipToOutputEdge.FromNodeOutputIndex = 0;
clipToOutputEdge.GraphOutputIndex = 0;

MLOperatorGraphDesc operatorGraphDesc = {};
operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
operatorGraphDesc.nodes = opDescs.data();

operatorGraphDesc.inputEdgeCount = 1;
operatorGraphDesc.inputEdges = &inputToCastEdge;

operatorGraphDesc.intermediateEdgeCount = 1;
operatorGraphDesc.intermediateEdges = &castToClipEdge;

operatorGraphDesc.outputEdgeCount = 1;
operatorGraphDesc.outputEdges = &clipToOutputEdge;

SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo);

Check warning on line 74 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp:74: Add #include <utility> for move [build/include_what_you_use] [4]
}
else
{

Check warning on line 77 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp:77: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CAST, &castDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
}

void Compute(const MLOperatorKernelContext& kernelContext)
Expand Down
5 changes: 0 additions & 5 deletions onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,6 @@ using CastNonStringTypes =
MLFloat16, BFloat16>;

TEST(CastOpTest, NonStringTypes) {
// TODO: Unskip when fixed #41968513
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because of the following error: Expected equality of these values: true and true";
}

boost::mp11::mp_for_each<boost::mp11::mp_product<std::pair, CastNonStringTypes, CastNonStringTypes>>(
CastNonStringTester{});
}
Expand Down
Loading