diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 19bfdc83aa3d4..3ab3c0deca377 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -1128,6 +1128,7 @@ Do not modify directly.*
|QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)|
+|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)|
| |
| |
|**Operator Domain:** *com.microsoft.dml*||||
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index ec0e637bce077..58d4c08b1ade8 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -271,7 +271,7 @@ InlinedVector> GenerateTransformers(
transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps));
- transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps));
+ transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique(cuda_rocm_eps));
#ifdef ENABLE_TRAINING
@@ -280,7 +280,7 @@ InlinedVector> GenerateTransformers(
transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps));
#endif
- transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps));
+ transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps));
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp
new file mode 100644
index 0000000000000..5c6e98c17aeca
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp
@@ -0,0 +1,164 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "precomp.h"
+
+namespace Dml
+{
+
+class DmlOperatorSkipLayerNormalization : public DmlOperator
+{
+public:
+ DmlOperatorSkipLayerNormalization(const MLOperatorKernelCreationContext& kernelCreationContext)
+ : DmlOperator(kernelCreationContext)
+ {
+ std::vector> kernelInputIndices = {0, 1, 2, 3, 4};
+
+ DmlOperator::Initialize(
+ kernelCreationContext,
+ kernelInputIndices,
+ std::nullopt,
+ kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0),
+ std::nullopt,
+ kernelCreationContext.GetTensorShapeDescription().GetInputTensorDimensionCount(0));
+
+ const float epsilon = kernelCreationContext.GetOptionalAttribute(AttrName::Epsilon, DefaultEpsilon);
+
+ int32_t onnxAxis = kernelCreationContext.GetOptionalAttribute(AttrName::Axis, -1);
+ uint32_t inputDimCount = kernelCreationContext.GetTensorShapeDescription().GetInputTensorDimensionCount(0);
+ onnxAxis = OperatorHelper::HandleNegativeAxis(onnxAxis, inputDimCount);
+ std::vector onnxAxes(inputDimCount - onnxAxis);
+ std::iota(onnxAxes.begin(), onnxAxes.end(), onnxAxis);
+
+ assert(m_inputTensorDescs.size() == 5);
+ assert(m_outputTensorDescs.size() == 1);
+
+ auto inputDesc = m_inputTensorDescs[0].GetDmlDesc();
+ auto skipDesc = m_inputTensorDescs[1].GetDmlDesc();
+ auto gammaDesc = m_inputTensorDescs[2].GetDmlDesc();
+ auto betaDesc = m_inputTensorDescs[3].GetDmlDesc();
+ auto biasDesc = m_inputTensorDescs[4].GetDmlDesc();
+ auto outputDesc = m_outputTensorDescs[0].GetDmlDesc();
+
+ TensorDesc inputSkipBiasTensorDesc(m_inputTensorDescs[0].GetDmlDataType(), m_inputTensorDescs[0].GetSizes());
+ DML_TENSOR_DESC inputSkipBiasDmlTensorDesc = inputSkipBiasTensorDesc.GetDmlDesc();
+
+ DML_ELEMENT_WISE_ADD_OPERATOR_DESC inputSkipAddDesc = {};
+ inputSkipAddDesc.ATensor = &inputDesc;
+ inputSkipAddDesc.BTensor = &skipDesc;
+ inputSkipAddDesc.OutputTensor = &inputSkipBiasDmlTensorDesc;
+ DML_OPERATOR_DESC inputSkipAddOpDesc = { DML_OPERATOR_ELEMENT_WISE_ADD, &inputSkipAddDesc };
+
+ DML_ELEMENT_WISE_ADD_OPERATOR_DESC inputSkipBiasAddDesc = {};
+ inputSkipBiasAddDesc.ATensor = &inputSkipBiasDmlTensorDesc;
+ inputSkipBiasAddDesc.BTensor = &biasDesc;
+ inputSkipBiasAddDesc.OutputTensor = &inputSkipBiasDmlTensorDesc;
+ DML_OPERATOR_DESC inputSkipBiasAddOpDesc = { DML_OPERATOR_ELEMENT_WISE_ADD, &inputSkipBiasAddDesc };
+
+ DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC mvnDesc = {};
+ mvnDesc.InputTensor = &inputSkipBiasDmlTensorDesc;
+ mvnDesc.ScaleTensor = &gammaDesc;
+ mvnDesc.BiasTensor = betaDesc.Desc ? &betaDesc : nullptr;
+ mvnDesc.OutputTensor = &outputDesc;
+ mvnDesc.Axes = onnxAxes.data();
+ mvnDesc.AxisCount = gsl::narrow_cast(onnxAxes.size());
+ mvnDesc.NormalizeVariance = true;
+ mvnDesc.Epsilon = epsilon;
+ mvnDesc.FusedActivation = nullptr;
+ DML_OPERATOR_DESC mvnOpDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &mvnDesc };
+
+ // Construct the graph
+ std::vector opDescs;
+ opDescs.reserve(3);
+
+ std::vector inputEdges;
+ inputEdges.reserve(5);
+
+ std::vector intermediateEdges;
+ intermediateEdges.reserve(2);
+
+ std::vector outputEdges;
+ outputEdges.reserve(1);
+
+ // Insert the Input + Skip operation into the graph
+ opDescs.push_back(&inputSkipAddOpDesc);
+
+ DML_INPUT_GRAPH_EDGE_DESC dataInputEdge = {};
+ dataInputEdge.GraphInputIndex = 0;
+ dataInputEdge.ToNodeIndex = 0;
+ dataInputEdge.ToNodeInputIndex = 0;
+ inputEdges.push_back(std::move(dataInputEdge));
+
+ DML_INPUT_GRAPH_EDGE_DESC skipInputEdge = {};
+ skipInputEdge.GraphInputIndex = 1;
+ skipInputEdge.ToNodeIndex = 0;
+ skipInputEdge.ToNodeInputIndex = 1;
+ inputEdges.push_back(std::move(skipInputEdge));
+
+ // Insert the InputSkip + Bias operation into the graph
+ if (biasDesc.Desc)
+ {
+ opDescs.push_back(&inputSkipBiasAddOpDesc);
+
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
+ intermediateEdge.FromNodeIndex = 0;
+ intermediateEdge.FromNodeOutputIndex = 0;
+ intermediateEdge.ToNodeIndex = 1;
+ intermediateEdge.ToNodeInputIndex = 0;
+ intermediateEdges.push_back(std::move(intermediateEdge));
+
+ DML_INPUT_GRAPH_EDGE_DESC biasInputEdge = {};
+ biasInputEdge.GraphInputIndex = 4;
+ biasInputEdge.ToNodeIndex = 1;
+ biasInputEdge.ToNodeInputIndex = 1;
+ inputEdges.push_back(std::move(biasInputEdge));
+ }
+
+ // Insert the MVN operation into the graph
+ opDescs.push_back(&mvnOpDesc);
+
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
+ intermediateEdge.FromNodeIndex = biasDesc.Desc ? 1 : 0;
+ intermediateEdge.FromNodeOutputIndex = 0;
+ intermediateEdge.ToNodeIndex = biasDesc.Desc ? 2 : 1;
+ intermediateEdge.ToNodeInputIndex = 0;
+ intermediateEdges.push_back(std::move(intermediateEdge));
+
+ DML_INPUT_GRAPH_EDGE_DESC gammaInputEdge = {};
+ gammaInputEdge.GraphInputIndex = 2;
+ gammaInputEdge.ToNodeIndex = biasDesc.Desc ? 2 : 1;
+ gammaInputEdge.ToNodeInputIndex = 1;
+ inputEdges.push_back(std::move(gammaInputEdge));
+
+ if (betaDesc.Desc)
+ {
+ DML_INPUT_GRAPH_EDGE_DESC betaInputEdge = {};
+ betaInputEdge.GraphInputIndex = 3;
+ betaInputEdge.ToNodeIndex = biasDesc.Desc ? 2 : 1;
+ betaInputEdge.ToNodeInputIndex = 2;
+ inputEdges.push_back(std::move(betaInputEdge));
+ }
+
+ DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
+ outputEdge.GraphOutputIndex = 0;
+ outputEdge.FromNodeIndex = biasDesc.Desc ? 2 : 1;
+ outputEdge.FromNodeOutputIndex = 0;
+ outputEdges.push_back(std::move(outputEdge));
+
+ MLOperatorGraphDesc operatorGraphDesc = {};
+ operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size());
+ operatorGraphDesc.inputEdges = inputEdges.data();
+ operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size());
+ operatorGraphDesc.intermediateEdges = intermediateEdges.data();
+ operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size());
+ operatorGraphDesc.outputEdges = outputEdges.data();
+ operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size());
+ operatorGraphDesc.nodesAsOpDesc = opDescs.data();
+
+ SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
+ }
+};
+
+DML_OP_DEFINE_CREATION_FUNCTION(SkipLayerNormalization, DmlOperatorSkipLayerNormalization);
+
+} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index bef3bf66d33b8..78fb9f73391d1 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -100,6 +100,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization15);
DML_OP_EXTERN_CREATION_FUNCTION(LayerNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(LayerNormalization17);
+DML_OP_EXTERN_CREATION_FUNCTION(SkipLayerNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(LRN);
DML_OP_EXTERN_CREATION_FUNCTION(MeanVarianceNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(LpNormalization);
@@ -748,6 +749,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
{REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
+ {REG_INFO_MS( 1, SkipLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
};
template
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
index 7df7a51b78f04..1b14aa327169b 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
@@ -1405,6 +1405,7 @@ using ShapeInferenceHelper_LRN = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_MeanVarianceNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization17 = GetOutputShapeAsInputShapeHelper;
+using ShapeInferenceHelper_SkipLayerNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LpNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_RNN = RecurrentHelper;
using ShapeInferenceHelper_GRU = RecurrentHelper;
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index 8b8f4fa4a4f0a..7dc6096c680cb 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -395,6 +395,7 @@ namespace OperatorHelper
static const int sc_sinceVer_FusedMatMul = 1;
static const int sc_sinceVer_QLinearSigmoid = 1;
static const int sc_sinceVer_Attention = 1;
+ static const int sc_sinceVer_SkipLayerNormalization = 1;
} // namespace MsftOperatorSet1
} // namespace OperatorHelper
diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
index f48be8b8c8bf8..c90a9f0466d96 100644
--- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
+++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
@@ -37,6 +37,7 @@ static void RunTest(
std::vector output_dims = input_dims;
auto rocm_ep = DefaultRocmExecutionProvider();
+ auto dml_ep = DefaultDmlExecutionProvider();
if (!use_float16) {
OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain);
test.AddInput("input", input_dims, input_data);
@@ -55,6 +56,7 @@ static void RunTest(
test.AddOutput("output", output_dims, output_data);
test.Run();
} else if (HasCudaEnvironment(530 /*min_cuda_architecture*/) ||
+ dml_ep != nullptr ||
rocm_ep != nullptr) {
OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain);
test.AddInput("input", input_dims, ToFloat16(input_data));
@@ -73,7 +75,9 @@ static void RunTest(
test.AddOutput("output", output_dims, ToFloat16(output_data));
std::vector> execution_providers;
- if (rocm_ep != nullptr) {
+ if (dml_ep != nullptr) {
+ execution_providers.push_back(DefaultDmlExecutionProvider());
+ } else if (rocm_ep != nullptr) {
execution_providers.push_back(DefaultRocmExecutionProvider());
} else {
execution_providers.push_back(DefaultCudaExecutionProvider());
diff --git a/onnxruntime/test/providers/provider_test_utils.cc b/onnxruntime/test/providers/provider_test_utils.cc
index 2ea1f3382357c..634b1565491b7 100644
--- a/onnxruntime/test/providers/provider_test_utils.cc
+++ b/onnxruntime/test/providers/provider_test_utils.cc
@@ -339,8 +339,10 @@ struct TensorCheck {
const bool has_rel_err = params.relative_error_.has_value();
float threshold = 0.001f;
-#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)
+#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA) || defined(USE_ROCM)
threshold = 0.005f;
+#elif defined(USE_DML)
+ threshold = 0.008f;
#endif
for (int i = 0; i < size; ++i) {
if (std::isnan(f_expected[i])) {