Skip to content

Commit bcb0656

Browse files
PatriceVignola authored and fuhengwu2021 committed
[DML EP] Add SkipLayerNormalization (microsoft#13849)
### Description Add SkipLayerNormalization for the DML EP
1 parent 1211c24 commit bcb0656

File tree

8 files changed

+179
-4
lines changed

8 files changed

+179
-4
lines changed

docs/OperatorKernels.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,6 +1128,7 @@ Do not modify directly.*
11281128
|QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
11291129
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
11301130
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
1131+
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)|
11311132
| |
11321133
| |
11331134
|**Operator Domain:** *com.microsoft.dml*||||

onnxruntime/core/optimizer/graph_transformer_utils.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
271271
transformers.emplace_back(std::make_unique<GatherToSliceFusion>(cpu_cuda_rocm_eps));
272272

273273
transformers.emplace_back(std::make_unique<MatmulTransposeFusion>(cpu_cuda_rocm_eps));
274-
transformers.emplace_back(std::make_unique<BiasGeluFusion>(cpu_cuda_rocm_eps));
274+
transformers.emplace_back(std::make_unique<BiasGeluFusion>(cpu_cuda_dml_rocm_eps));
275275
transformers.emplace_back(std::make_unique<BiasSoftmaxFusion>(cpu_cuda_rocm_eps));
276276
transformers.emplace_back(std::make_unique<BiasDropoutFusion>(cuda_rocm_eps));
277277
#ifdef ENABLE_TRAINING
@@ -280,7 +280,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
280280
transformers.emplace_back(std::make_unique<SceLossGradBiasFusion>(cpu_cuda_rocm_eps));
281281
#endif
282282

283-
transformers.emplace_back(std::make_unique<SkipLayerNormFusion>(cpu_cuda_rocm_eps));
283+
transformers.emplace_back(std::make_unique<SkipLayerNormFusion>(cpu_cuda_dml_rocm_eps));
284284

285285
transformers.emplace_back(std::make_unique<FastGeluFusion>(cpu_cuda_rocm_eps));
286286
transformers.emplace_back(std::make_unique<QuickGeluFusion>(cpu_cuda_rocm_eps));
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "precomp.h"

namespace Dml
{

// DML EP kernel for the com.microsoft SkipLayerNormalization contrib operator.
//
// The operator is lowered to a small fused DML graph:
//     sum    = input + skip            (element-wise add, node 0)
//     sum    = sum + bias              (optional second add, node 1, only when bias is bound)
//     output = MVN(sum) * gamma + beta (DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, last node)
// Normalization runs over the trailing axes starting at the `axis` attribute
// (default -1, i.e. the last dimension only).
class DmlOperatorSkipLayerNormalization : public DmlOperator
{
public:
    DmlOperatorSkipLayerNormalization(const MLOperatorKernelCreationContext& kernelCreationContext)
    :   DmlOperator(kernelCreationContext)
    {
        // Bind all five ONNX inputs: input, skip, gamma, beta, bias.
        std::vector<std::optional<uint32_t>> inputIndices = {0, 1, 2, 3, 4};

        DmlOperator::Initialize(
            kernelCreationContext,
            inputIndices,
            std::nullopt,
            kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0),
            std::nullopt,
            kernelCreationContext.GetTensorShapeDescription().GetInputTensorDimensionCount(0));

        const float epsilon = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Epsilon, DefaultEpsilon);

        // Build the list of axes to normalize over: [axis, axis+1, ..., rank-1].
        int32_t axis = kernelCreationContext.GetOptionalAttribute<int32_t>(AttrName::Axis, -1);
        const uint32_t inputRank = kernelCreationContext.GetTensorShapeDescription().GetInputTensorDimensionCount(0);
        axis = OperatorHelper::HandleNegativeAxis(axis, inputRank);
        std::vector<uint32_t> normalizationAxes(inputRank - axis);
        std::iota(normalizationAxes.begin(), normalizationAxes.end(), axis);

        assert(m_inputTensorDescs.size() == 5);
        assert(m_outputTensorDescs.size() == 1);

        auto inputDesc = m_inputTensorDescs[0].GetDmlDesc();
        auto skipDesc = m_inputTensorDescs[1].GetDmlDesc();
        auto gammaDesc = m_inputTensorDescs[2].GetDmlDesc();
        auto betaDesc = m_inputTensorDescs[3].GetDmlDesc();
        auto biasDesc = m_inputTensorDescs[4].GetDmlDesc();
        auto outputDesc = m_outputTensorDescs[0].GetDmlDesc();

        // Optional inputs are absent when their DML desc pointer is null.
        const bool hasBias = (biasDesc.Desc != nullptr);
        const bool hasBeta = (betaDesc.Desc != nullptr);

        // Intermediate tensor carrying input+skip (and optionally +bias);
        // same data type and sizes as the primary input.
        TensorDesc sumTensorDesc(m_inputTensorDescs[0].GetDmlDataType(), m_inputTensorDescs[0].GetSizes());
        DML_TENSOR_DESC sumDmlTensorDesc = sumTensorDesc.GetDmlDesc();

        // Node 0: input + skip.
        DML_ELEMENT_WISE_ADD_OPERATOR_DESC addInputSkipDesc = {};
        addInputSkipDesc.ATensor = &inputDesc;
        addInputSkipDesc.BTensor = &skipDesc;
        addInputSkipDesc.OutputTensor = &sumDmlTensorDesc;
        DML_OPERATOR_DESC addInputSkipOpDesc = { DML_OPERATOR_ELEMENT_WISE_ADD, &addInputSkipDesc };

        // Optional node 1: (input + skip) + bias.
        DML_ELEMENT_WISE_ADD_OPERATOR_DESC addBiasDesc = {};
        addBiasDesc.ATensor = &sumDmlTensorDesc;
        addBiasDesc.BTensor = &biasDesc;
        addBiasDesc.OutputTensor = &sumDmlTensorDesc;
        DML_OPERATOR_DESC addBiasOpDesc = { DML_OPERATOR_ELEMENT_WISE_ADD, &addBiasDesc };

        // Final node: mean-variance normalization with gamma scale and optional beta bias.
        DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC mvnDesc = {};
        mvnDesc.InputTensor = &sumDmlTensorDesc;
        mvnDesc.ScaleTensor = &gammaDesc;
        mvnDesc.BiasTensor = hasBeta ? &betaDesc : nullptr;
        mvnDesc.OutputTensor = &outputDesc;
        mvnDesc.Axes = normalizationAxes.data();
        mvnDesc.AxisCount = gsl::narrow_cast<uint32_t>(normalizationAxes.size());
        mvnDesc.NormalizeVariance = true;
        mvnDesc.Epsilon = epsilon;
        mvnDesc.FusedActivation = nullptr;
        DML_OPERATOR_DESC mvnOpDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &mvnDesc };

        // Assemble the fused graph.
        std::vector<const DML_OPERATOR_DESC*> opDescs;
        opDescs.reserve(3);

        std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
        inputEdges.reserve(5);

        std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
        intermediateEdges.reserve(2);

        std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
        outputEdges.reserve(1);

        // Wire node 0 (input + skip) to graph inputs 0 and 1.
        opDescs.push_back(&addInputSkipOpDesc);

        DML_INPUT_GRAPH_EDGE_DESC dataEdge = {};
        dataEdge.GraphInputIndex = 0;
        dataEdge.ToNodeIndex = 0;
        dataEdge.ToNodeInputIndex = 0;
        inputEdges.push_back(dataEdge);

        DML_INPUT_GRAPH_EDGE_DESC skipEdge = {};
        skipEdge.GraphInputIndex = 1;
        skipEdge.ToNodeIndex = 0;
        skipEdge.ToNodeInputIndex = 1;
        inputEdges.push_back(skipEdge);

        if (hasBias)
        {
            // Wire node 1: feed node 0's sum and graph input 4 (bias) into the second add.
            opDescs.push_back(&addBiasOpDesc);

            DML_INTERMEDIATE_GRAPH_EDGE_DESC sumToBiasEdge = {};
            sumToBiasEdge.FromNodeIndex = 0;
            sumToBiasEdge.FromNodeOutputIndex = 0;
            sumToBiasEdge.ToNodeIndex = 1;
            sumToBiasEdge.ToNodeInputIndex = 0;
            intermediateEdges.push_back(sumToBiasEdge);

            DML_INPUT_GRAPH_EDGE_DESC biasEdge = {};
            biasEdge.GraphInputIndex = 4;
            biasEdge.ToNodeIndex = 1;
            biasEdge.ToNodeInputIndex = 1;
            inputEdges.push_back(biasEdge);
        }

        // The MVN node follows the last add: index 2 with bias, 1 without.
        const uint32_t mvnNodeIndex = hasBias ? 2u : 1u;
        opDescs.push_back(&mvnOpDesc);

        DML_INTERMEDIATE_GRAPH_EDGE_DESC sumToMvnEdge = {};
        sumToMvnEdge.FromNodeIndex = mvnNodeIndex - 1;
        sumToMvnEdge.FromNodeOutputIndex = 0;
        sumToMvnEdge.ToNodeIndex = mvnNodeIndex;
        sumToMvnEdge.ToNodeInputIndex = 0;
        intermediateEdges.push_back(sumToMvnEdge);

        DML_INPUT_GRAPH_EDGE_DESC gammaEdge = {};
        gammaEdge.GraphInputIndex = 2;
        gammaEdge.ToNodeIndex = mvnNodeIndex;
        gammaEdge.ToNodeInputIndex = 1;
        inputEdges.push_back(gammaEdge);

        if (hasBeta)
        {
            DML_INPUT_GRAPH_EDGE_DESC betaEdge = {};
            betaEdge.GraphInputIndex = 3;
            betaEdge.ToNodeIndex = mvnNodeIndex;
            betaEdge.ToNodeInputIndex = 2;
            inputEdges.push_back(betaEdge);
        }

        // Graph output 0 is the MVN node's result.
        DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
        outputEdge.GraphOutputIndex = 0;
        outputEdge.FromNodeIndex = mvnNodeIndex;
        outputEdge.FromNodeOutputIndex = 0;
        outputEdges.push_back(outputEdge);

        MLOperatorGraphDesc operatorGraphDesc = {};
        operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
        operatorGraphDesc.inputEdges = inputEdges.data();
        operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
        operatorGraphDesc.intermediateEdges = intermediateEdges.data();
        operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
        operatorGraphDesc.outputEdges = outputEdges.data();
        operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
        operatorGraphDesc.nodesAsOpDesc = opDescs.data();

        SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
    }
};

DML_OP_DEFINE_CREATION_FUNCTION(SkipLayerNormalization, DmlOperatorSkipLayerNormalization);

} // namespace Dml

onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization);
100100
DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization15);
101101
DML_OP_EXTERN_CREATION_FUNCTION(LayerNormalization);
102102
DML_OP_EXTERN_CREATION_FUNCTION(LayerNormalization17);
103+
DML_OP_EXTERN_CREATION_FUNCTION(SkipLayerNormalization);
103104
DML_OP_EXTERN_CREATION_FUNCTION(LRN);
104105
DML_OP_EXTERN_CREATION_FUNCTION(MeanVarianceNormalization);
105106
DML_OP_EXTERN_CREATION_FUNCTION(LpNormalization);
@@ -748,6 +749,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
748749
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
749750
{REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
750751
{REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
752+
{REG_INFO_MS( 1, SkipLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
751753
};
752754

753755
template<typename T>

onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,6 +1405,7 @@ using ShapeInferenceHelper_LRN = GetOutputShapeAsInputShapeHelper;
14051405
using ShapeInferenceHelper_MeanVarianceNormalization = GetOutputShapeAsInputShapeHelper;
14061406
using ShapeInferenceHelper_LayerNormalization = GetOutputShapeAsInputShapeHelper;
14071407
using ShapeInferenceHelper_LayerNormalization17 = GetOutputShapeAsInputShapeHelper;
1408+
using ShapeInferenceHelper_SkipLayerNormalization = GetOutputShapeAsInputShapeHelper;
14081409
using ShapeInferenceHelper_LpNormalization = GetOutputShapeAsInputShapeHelper;
14091410
using ShapeInferenceHelper_RNN = RecurrentHelper;
14101411
using ShapeInferenceHelper_GRU = RecurrentHelper;

onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,7 @@ namespace OperatorHelper
395395
static const int sc_sinceVer_FusedMatMul = 1;
396396
static const int sc_sinceVer_QLinearSigmoid = 1;
397397
static const int sc_sinceVer_Attention = 1;
398+
static const int sc_sinceVer_SkipLayerNormalization = 1;
398399
} // namespace MsftOperatorSet1
399400

400401
} // namespace OperatorHelper

onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ static void RunTest(
3737
std::vector<int64_t> output_dims = input_dims;
3838

3939
auto rocm_ep = DefaultRocmExecutionProvider();
40+
auto dml_ep = DefaultDmlExecutionProvider();
4041
if (!use_float16) {
4142
OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain);
4243
test.AddInput<float>("input", input_dims, input_data);
@@ -55,6 +56,7 @@ static void RunTest(
5556
test.AddOutput<float>("output", output_dims, output_data);
5657
test.Run();
5758
} else if (HasCudaEnvironment(530 /*min_cuda_architecture*/) ||
59+
dml_ep != nullptr ||
5860
rocm_ep != nullptr) {
5961
OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain);
6062
test.AddInput<MLFloat16>("input", input_dims, ToFloat16(input_data));
@@ -73,7 +75,9 @@ static void RunTest(
7375
test.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
7476

7577
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
76-
if (rocm_ep != nullptr) {
78+
if (dml_ep != nullptr) {
79+
execution_providers.push_back(DefaultDmlExecutionProvider());
80+
} else if (rocm_ep != nullptr) {
7781
execution_providers.push_back(DefaultRocmExecutionProvider());
7882
} else {
7983
execution_providers.push_back(DefaultCudaExecutionProvider());

onnxruntime/test/providers/provider_test_utils.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,10 @@ struct TensorCheck<MLFloat16> {
339339
const bool has_rel_err = params.relative_error_.has_value();
340340

341341
float threshold = 0.001f;
342-
#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)
342+
#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA) || defined(USE_ROCM)
343343
threshold = 0.005f;
344+
#elif defined(USE_DML)
345+
threshold = 0.008f;
344346
#endif
345347
for (int i = 0; i < size; ++i) {
346348
if (std::isnan(f_expected[i])) {

0 commit comments

Comments
 (0)