diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 19bfdc83aa3d4..3ab3c0deca377 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -1128,6 +1128,7 @@ Do not modify directly.* |QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| +|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)| | | | | |**Operator Domain:** *com.microsoft.dml*|||| diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index ec0e637bce077..58d4c08b1ade8 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -271,7 +271,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cuda_rocm_eps)); #ifdef ENABLE_TRAINING @@ -280,7 +280,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); #endif - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp new file mode 100644 index 0000000000000..5c6e98c17aeca --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#include "precomp.h"
+
+namespace Dml
+{
+
+class DmlOperatorSkipLayerNormalization : public DmlOperator
+{
+public:
+    DmlOperatorSkipLayerNormalization(const MLOperatorKernelCreationContext& kernelCreationContext)
+    :   DmlOperator(kernelCreationContext)
+    {
+        std::vector<std::optional<uint32_t>> kernelInputIndices = {0, 1, 2, 3, 4};
+
+        DmlOperator::Initialize(
+            kernelCreationContext,
+            kernelInputIndices,
+            std::nullopt,
+            kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0),
+            std::nullopt,
+            kernelCreationContext.GetTensorShapeDescription().GetInputTensorDimensionCount(0));
+
+        const float epsilon = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Epsilon, DefaultEpsilon);
+
+        int32_t onnxAxis = kernelCreationContext.GetOptionalAttribute<int32_t>(AttrName::Axis, -1);
+        uint32_t inputDimCount = kernelCreationContext.GetTensorShapeDescription().GetInputTensorDimensionCount(0);
+        onnxAxis = OperatorHelper::HandleNegativeAxis(onnxAxis, inputDimCount);
+        std::vector<uint32_t> onnxAxes(inputDimCount - onnxAxis);
+        std::iota(onnxAxes.begin(), onnxAxes.end(), onnxAxis);
+
+        assert(m_inputTensorDescs.size() == 5);
+        assert(m_outputTensorDescs.size() == 1);
+
+        auto inputDesc = m_inputTensorDescs[0].GetDmlDesc();
+        auto skipDesc = m_inputTensorDescs[1].GetDmlDesc();
+        auto gammaDesc = m_inputTensorDescs[2].GetDmlDesc();
+        auto betaDesc = m_inputTensorDescs[3].GetDmlDesc();
+        auto biasDesc = m_inputTensorDescs[4].GetDmlDesc();
+        auto outputDesc = m_outputTensorDescs[0].GetDmlDesc();
+
+        TensorDesc inputSkipBiasTensorDesc(m_inputTensorDescs[0].GetDmlDataType(), m_inputTensorDescs[0].GetSizes());
+        DML_TENSOR_DESC inputSkipBiasDmlTensorDesc = inputSkipBiasTensorDesc.GetDmlDesc();
+
+        DML_ELEMENT_WISE_ADD_OPERATOR_DESC inputSkipAddDesc = {};
+        inputSkipAddDesc.ATensor = &inputDesc;
+        inputSkipAddDesc.BTensor = &skipDesc;
+        inputSkipAddDesc.OutputTensor = &inputSkipBiasDmlTensorDesc;
+        DML_OPERATOR_DESC inputSkipAddOpDesc = { DML_OPERATOR_ELEMENT_WISE_ADD, &inputSkipAddDesc };
+
+
DML_ELEMENT_WISE_ADD_OPERATOR_DESC inputSkipBiasAddDesc = {};
+        inputSkipBiasAddDesc.ATensor = &inputSkipBiasDmlTensorDesc;
+        inputSkipBiasAddDesc.BTensor = &biasDesc;
+        inputSkipBiasAddDesc.OutputTensor = &inputSkipBiasDmlTensorDesc;
+        DML_OPERATOR_DESC inputSkipBiasAddOpDesc = { DML_OPERATOR_ELEMENT_WISE_ADD, &inputSkipBiasAddDesc };
+
+        DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC mvnDesc = {};
+        mvnDesc.InputTensor = &inputSkipBiasDmlTensorDesc;
+        mvnDesc.ScaleTensor = &gammaDesc;
+        mvnDesc.BiasTensor = betaDesc.Desc ? &betaDesc : nullptr;
+        mvnDesc.OutputTensor = &outputDesc;
+        mvnDesc.Axes = onnxAxes.data();
+        mvnDesc.AxisCount = gsl::narrow_cast<uint32_t>(onnxAxes.size());
+        mvnDesc.NormalizeVariance = true;
+        mvnDesc.Epsilon = epsilon;
+        mvnDesc.FusedActivation = nullptr;
+        DML_OPERATOR_DESC mvnOpDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &mvnDesc };
+
+        // Construct the graph
+        std::vector<const DML_OPERATOR_DESC*> opDescs;
+        opDescs.reserve(3);
+
+        std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
+        inputEdges.reserve(5);
+
+        std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
+        intermediateEdges.reserve(2);
+
+        std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
+        outputEdges.reserve(1);
+
+        // Insert the Input + Skip operation into the graph
+        opDescs.push_back(&inputSkipAddOpDesc);
+
+        DML_INPUT_GRAPH_EDGE_DESC dataInputEdge = {};
+        dataInputEdge.GraphInputIndex = 0;
+        dataInputEdge.ToNodeIndex = 0;
+        dataInputEdge.ToNodeInputIndex = 0;
+        inputEdges.push_back(std::move(dataInputEdge));
+
+        DML_INPUT_GRAPH_EDGE_DESC skipInputEdge = {};
+        skipInputEdge.GraphInputIndex = 1;
+        skipInputEdge.ToNodeIndex = 0;
+        skipInputEdge.ToNodeInputIndex = 1;
+        inputEdges.push_back(std::move(skipInputEdge));
+
+        // Insert the InputSkip + Bias operation into the graph
+        if (biasDesc.Desc)
+        {
+            opDescs.push_back(&inputSkipBiasAddOpDesc);
+
+            DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
+            intermediateEdge.FromNodeIndex = 0;
+            intermediateEdge.FromNodeOutputIndex = 0;
+            intermediateEdge.ToNodeIndex = 1;
+            intermediateEdge.ToNodeInputIndex = 0;
+
intermediateEdges.push_back(std::move(intermediateEdge));
+
+            DML_INPUT_GRAPH_EDGE_DESC biasInputEdge = {};
+            biasInputEdge.GraphInputIndex = 4;
+            biasInputEdge.ToNodeIndex = 1;
+            biasInputEdge.ToNodeInputIndex = 1;
+            inputEdges.push_back(std::move(biasInputEdge));
+        }
+
+        // Insert the MVN operation into the graph
+        opDescs.push_back(&mvnOpDesc);
+
+        DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
+        intermediateEdge.FromNodeIndex = biasDesc.Desc ? 1 : 0;
+        intermediateEdge.FromNodeOutputIndex = 0;
+        intermediateEdge.ToNodeIndex = biasDesc.Desc ? 2 : 1;
+        intermediateEdge.ToNodeInputIndex = 0;
+        intermediateEdges.push_back(std::move(intermediateEdge));
+
+        DML_INPUT_GRAPH_EDGE_DESC gammaInputEdge = {};
+        gammaInputEdge.GraphInputIndex = 2;
+        gammaInputEdge.ToNodeIndex = biasDesc.Desc ? 2 : 1;
+        gammaInputEdge.ToNodeInputIndex = 1;
+        inputEdges.push_back(std::move(gammaInputEdge));
+
+        if (betaDesc.Desc)
+        {
+            DML_INPUT_GRAPH_EDGE_DESC betaInputEdge = {};
+            betaInputEdge.GraphInputIndex = 3;
+            betaInputEdge.ToNodeIndex = biasDesc.Desc ? 2 : 1;
+            betaInputEdge.ToNodeInputIndex = 2;
+            inputEdges.push_back(std::move(betaInputEdge));
+        }
+
+        DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
+        outputEdge.GraphOutputIndex = 0;
+        outputEdge.FromNodeIndex = biasDesc.Desc ? 2 : 1;
+        outputEdge.FromNodeOutputIndex = 0;
+        outputEdges.push_back(std::move(outputEdge));
+
+        MLOperatorGraphDesc operatorGraphDesc = {};
+        operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
+        operatorGraphDesc.inputEdges = inputEdges.data();
+        operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
+        operatorGraphDesc.intermediateEdges = intermediateEdges.data();
+        operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
+        operatorGraphDesc.outputEdges = outputEdges.data();
+        operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
+        operatorGraphDesc.nodesAsOpDesc = opDescs.data();
+
+        SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
+    }
+};
+
+DML_OP_DEFINE_CREATION_FUNCTION(SkipLayerNormalization, DmlOperatorSkipLayerNormalization);
+
+} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index bef3bf66d33b8..78fb9f73391d1 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -100,6 +100,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization);
 DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization15);
 DML_OP_EXTERN_CREATION_FUNCTION(LayerNormalization);
 DML_OP_EXTERN_CREATION_FUNCTION(LayerNormalization17);
+DML_OP_EXTERN_CREATION_FUNCTION(SkipLayerNormalization);
 DML_OP_EXTERN_CREATION_FUNCTION(LRN);
 DML_OP_EXTERN_CREATION_FUNCTION(MeanVarianceNormalization);
 DML_OP_EXTERN_CREATION_FUNCTION(LpNormalization);
@@ -748,6 +749,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
 {REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
 {REG_INFO( 11, DynamicQuantizeLinear,
typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)}, {REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)}, + {REG_INFO_MS( 1, SkipLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, }; template diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 7df7a51b78f04..1b14aa327169b 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1405,6 +1405,7 @@ using ShapeInferenceHelper_LRN = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_MeanVarianceNormalization = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_LayerNormalization = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_LayerNormalization17 = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_SkipLayerNormalization = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_LpNormalization = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_RNN = RecurrentHelper; using ShapeInferenceHelper_GRU = RecurrentHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 8b8f4fa4a4f0a..7dc6096c680cb 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -395,6 +395,7 @@ namespace OperatorHelper static const int sc_sinceVer_FusedMatMul = 1; static const int sc_sinceVer_QLinearSigmoid = 1; static const int sc_sinceVer_Attention = 1; + static const int sc_sinceVer_SkipLayerNormalization = 1; } // namespace 
MsftOperatorSet1 } // namespace OperatorHelper diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc index f48be8b8c8bf8..c90a9f0466d96 100644 --- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc @@ -37,6 +37,7 @@ static void RunTest( std::vector output_dims = input_dims; auto rocm_ep = DefaultRocmExecutionProvider(); + auto dml_ep = DefaultDmlExecutionProvider(); if (!use_float16) { OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain); test.AddInput("input", input_dims, input_data); @@ -55,6 +56,7 @@ static void RunTest( test.AddOutput("output", output_dims, output_data); test.Run(); } else if (HasCudaEnvironment(530 /*min_cuda_architecture*/) || + dml_ep != nullptr || rocm_ep != nullptr) { OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain); test.AddInput("input", input_dims, ToFloat16(input_data)); @@ -73,7 +75,9 @@ static void RunTest( test.AddOutput("output", output_dims, ToFloat16(output_data)); std::vector> execution_providers; - if (rocm_ep != nullptr) { + if (dml_ep != nullptr) { + execution_providers.push_back(DefaultDmlExecutionProvider()); + } else if (rocm_ep != nullptr) { execution_providers.push_back(DefaultRocmExecutionProvider()); } else { execution_providers.push_back(DefaultCudaExecutionProvider()); diff --git a/onnxruntime/test/providers/provider_test_utils.cc b/onnxruntime/test/providers/provider_test_utils.cc index 2ea1f3382357c..634b1565491b7 100644 --- a/onnxruntime/test/providers/provider_test_utils.cc +++ b/onnxruntime/test/providers/provider_test_utils.cc @@ -339,8 +339,10 @@ struct TensorCheck { const bool has_rel_err = params.relative_error_.has_value(); float threshold = 0.001f; -#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) +#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA) || 
defined(USE_ROCM) threshold = 0.005f; +#elif defined(USE_DML) + threshold = 0.008f; #endif for (int i = 0; i < size; ++i) { if (std::isnan(f_expected[i])) {