1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -1128,6 +1128,7 @@ Do not modify directly.*
|QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)|
| |
| |
|**Operator Domain:** *com.microsoft.dml*||||
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -271,7 +271,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<GatherToSliceFusion>(cpu_cuda_rocm_eps));

transformers.emplace_back(std::make_unique<MatmulTransposeFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<BiasGeluFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<BiasGeluFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<BiasSoftmaxFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<BiasDropoutFusion>(cuda_rocm_eps));
#ifdef ENABLE_TRAINING
@@ -280,7 +280,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<SceLossGradBiasFusion>(cpu_cuda_rocm_eps));
#endif

transformers.emplace_back(std::make_unique<SkipLayerNormFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<SkipLayerNormFusion>(cpu_cuda_dml_rocm_eps));

transformers.emplace_back(std::make_unique<FastGeluFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<QuickGeluFusion>(cpu_cuda_rocm_eps));
@@ -0,0 +1,164 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "precomp.h"

namespace Dml
{

class DmlOperatorSkipLayerNormalization : public DmlOperator
{
public:
DmlOperatorSkipLayerNormalization(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext)
{
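// Kernel inputs follow the contrib op's declared order: input, skip, gamma, beta, bias.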
std::vector<std::optional<uint32_t>> kernelInputIndices = {0, 1, 2, 3, 4};

DmlOperator::Initialize(
kernelCreationContext,
kernelInputIndices,
std::nullopt,
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0),
std::nullopt,
kernelCreationContext.GetTensorShapeDescription().GetInputTensorDimensionCount(0));

const float epsilon = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Epsilon, DefaultEpsilon);

int32_t onnxAxis = kernelCreationContext.GetOptionalAttribute<int32_t>(AttrName::Axis, -1);
uint32_t inputDimCount = kernelCreationContext.GetTensorShapeDescription().GetInputTensorDimensionCount(0);
onnxAxis = OperatorHelper::HandleNegativeAxis(onnxAxis, inputDimCount);
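// Normalize over every dimension from the resolved axis through the last, matching LayerNormalization's reduction semantics.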
std::vector<uint32_t> onnxAxes(inputDimCount - onnxAxis);
std::iota(onnxAxes.begin(), onnxAxes.end(), onnxAxis);

assert(m_inputTensorDescs.size() == 5);
assert(m_outputTensorDescs.size() == 1);

auto inputDesc = m_inputTensorDescs[0].GetDmlDesc();
auto skipDesc = m_inputTensorDescs[1].GetDmlDesc();
auto gammaDesc = m_inputTensorDescs[2].GetDmlDesc();
auto betaDesc = m_inputTensorDescs[3].GetDmlDesc();
auto biasDesc = m_inputTensorDescs[4].GetDmlDesc();
auto outputDesc = m_outputTensorDescs[0].GetDmlDesc();

// Intermediate tensor that carries the running sum (input + skip, and optionally + bias); it reuses the input tensor's data type and sizes.
TensorDesc inputSkipBiasTensorDesc(m_inputTensorDescs[0].GetDmlDataType(), m_inputTensorDescs[0].GetSizes());
DML_TENSOR_DESC inputSkipBiasDmlTensorDesc = inputSkipBiasTensorDesc.GetDmlDesc();

DML_ELEMENT_WISE_ADD_OPERATOR_DESC inputSkipAddDesc = {};
inputSkipAddDesc.ATensor = &inputDesc;
inputSkipAddDesc.BTensor = &skipDesc;
inputSkipAddDesc.OutputTensor = &inputSkipBiasDmlTensorDesc;
DML_OPERATOR_DESC inputSkipAddOpDesc = { DML_OPERATOR_ELEMENT_WISE_ADD, &inputSkipAddDesc };

DML_ELEMENT_WISE_ADD_OPERATOR_DESC inputSkipBiasAddDesc = {};
inputSkipBiasAddDesc.ATensor = &inputSkipBiasDmlTensorDesc;
inputSkipBiasAddDesc.BTensor = &biasDesc;
inputSkipBiasAddDesc.OutputTensor = &inputSkipBiasDmlTensorDesc;
DML_OPERATOR_DESC inputSkipBiasAddOpDesc = { DML_OPERATOR_ELEMENT_WISE_ADD, &inputSkipBiasAddDesc };

// Mean-variance normalization over the trailing axes performs the layer-norm step, using gamma as the scale tensor and beta as the optional bias.
DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC mvnDesc = {};
mvnDesc.InputTensor = &inputSkipBiasDmlTensorDesc;
mvnDesc.ScaleTensor = &gammaDesc;
mvnDesc.BiasTensor = betaDesc.Desc ? &betaDesc : nullptr;
mvnDesc.OutputTensor = &outputDesc;
mvnDesc.Axes = onnxAxes.data();
mvnDesc.AxisCount = gsl::narrow_cast<uint32_t>(onnxAxes.size());
mvnDesc.NormalizeVariance = true;
mvnDesc.Epsilon = epsilon;
mvnDesc.FusedActivation = nullptr;
DML_OPERATOR_DESC mvnOpDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &mvnDesc };

// Construct the graph
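// At most three nodes are created: the input+skip add, an optional bias add, and the normalization.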
std::vector<const DML_OPERATOR_DESC*> opDescs;
opDescs.reserve(3);

std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
inputEdges.reserve(5);

std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
intermediateEdges.reserve(2);

std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
outputEdges.reserve(1);

// Insert the Input + Skip operation into the graph
opDescs.push_back(&inputSkipAddOpDesc);

DML_INPUT_GRAPH_EDGE_DESC dataInputEdge = {};
dataInputEdge.GraphInputIndex = 0;
dataInputEdge.ToNodeIndex = 0;
dataInputEdge.ToNodeInputIndex = 0;
inputEdges.push_back(std::move(dataInputEdge));

DML_INPUT_GRAPH_EDGE_DESC skipInputEdge = {};
skipInputEdge.GraphInputIndex = 1;
skipInputEdge.ToNodeIndex = 0;
skipInputEdge.ToNodeInputIndex = 1;
inputEdges.push_back(std::move(skipInputEdge));

// Insert the InputSkip + Bias operation into the graph
if (biasDesc.Desc)
{
opDescs.push_back(&inputSkipBiasAddOpDesc);

DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
intermediateEdge.FromNodeIndex = 0;
intermediateEdge.FromNodeOutputIndex = 0;
intermediateEdge.ToNodeIndex = 1;
intermediateEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(std::move(intermediateEdge));

DML_INPUT_GRAPH_EDGE_DESC biasInputEdge = {};
biasInputEdge.GraphInputIndex = 4;
biasInputEdge.ToNodeIndex = 1;
biasInputEdge.ToNodeInputIndex = 1;
inputEdges.push_back(std::move(biasInputEdge));
}

// Insert the MVN operation into the graph
opDescs.push_back(&mvnOpDesc);

// Connect the output of the last add node to the MVN node; node indices shift down by one when the bias add is absent.
DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
intermediateEdge.FromNodeIndex = biasDesc.Desc ? 1 : 0;
intermediateEdge.FromNodeOutputIndex = 0;
intermediateEdge.ToNodeIndex = biasDesc.Desc ? 2 : 1;
intermediateEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(std::move(intermediateEdge));

DML_INPUT_GRAPH_EDGE_DESC gammaInputEdge = {};
gammaInputEdge.GraphInputIndex = 2;
gammaInputEdge.ToNodeIndex = biasDesc.Desc ? 2 : 1;
gammaInputEdge.ToNodeInputIndex = 1;
inputEdges.push_back(std::move(gammaInputEdge));

if (betaDesc.Desc)
{
DML_INPUT_GRAPH_EDGE_DESC betaInputEdge = {};
betaInputEdge.GraphInputIndex = 3;
betaInputEdge.ToNodeIndex = biasDesc.Desc ? 2 : 1;
betaInputEdge.ToNodeInputIndex = 2;
inputEdges.push_back(std::move(betaInputEdge));
}

DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
outputEdge.GraphOutputIndex = 0;
outputEdge.FromNodeIndex = biasDesc.Desc ? 2 : 1;
outputEdge.FromNodeOutputIndex = 0;
outputEdges.push_back(std::move(outputEdge));

MLOperatorGraphDesc operatorGraphDesc = {};
operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
operatorGraphDesc.inputEdges = inputEdges.data();
operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
operatorGraphDesc.intermediateEdges = intermediateEdges.data();
operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
operatorGraphDesc.outputEdges = outputEdges.data();
operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
operatorGraphDesc.nodesAsOpDesc = opDescs.data();

SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
}
};

DML_OP_DEFINE_CREATION_FUNCTION(SkipLayerNormalization, DmlOperatorSkipLayerNormalization);

} // namespace Dml
@@ -100,6 +100,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization15);
DML_OP_EXTERN_CREATION_FUNCTION(LayerNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(LayerNormalization17);
DML_OP_EXTERN_CREATION_FUNCTION(SkipLayerNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(LRN);
DML_OP_EXTERN_CREATION_FUNCTION(MeanVarianceNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(LpNormalization);
@@ -748,6 +749,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
{REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
{REG_INFO_MS( 1, SkipLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
};

template<typename T>
@@ -1405,6 +1405,7 @@ using ShapeInferenceHelper_LRN = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_MeanVarianceNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization17 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_SkipLayerNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LpNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_RNN = RecurrentHelper;
using ShapeInferenceHelper_GRU = RecurrentHelper;
@@ -395,6 +395,7 @@ namespace OperatorHelper
static const int sc_sinceVer_FusedMatMul = 1;
static const int sc_sinceVer_QLinearSigmoid = 1;
static const int sc_sinceVer_Attention = 1;
static const int sc_sinceVer_SkipLayerNormalization = 1;
} // namespace MsftOperatorSet1

} // namespace OperatorHelper
6 changes: 5 additions & 1 deletion onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
@@ -37,6 +37,7 @@ static void RunTest(
std::vector<int64_t> output_dims = input_dims;

auto rocm_ep = DefaultRocmExecutionProvider();
auto dml_ep = DefaultDmlExecutionProvider();
if (!use_float16) {
OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain);
test.AddInput<float>("input", input_dims, input_data);
@@ -55,6 +56,7 @@ static void RunTest(
test.AddOutput<float>("output", output_dims, output_data);
test.Run();
} else if (HasCudaEnvironment(530 /*min_cuda_architecture*/) ||
dml_ep != nullptr ||
rocm_ep != nullptr) {
OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain);
test.AddInput<MLFloat16>("input", input_dims, ToFloat16(input_data));
@@ -73,7 +75,9 @@ static void RunTest(
test.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
if (rocm_ep != nullptr) {
if (dml_ep != nullptr) {
execution_providers.push_back(DefaultDmlExecutionProvider());
} else if (rocm_ep != nullptr) {
execution_providers.push_back(DefaultRocmExecutionProvider());
} else {
execution_providers.push_back(DefaultCudaExecutionProvider());
4 changes: 3 additions & 1 deletion onnxruntime/test/providers/provider_test_utils.cc
@@ -339,8 +339,10 @@ struct TensorCheck<MLFloat16> {
const bool has_rel_err = params.relative_error_.has_value();

float threshold = 0.001f;
#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)
#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA) || defined(USE_ROCM)
threshold = 0.005f;
#elif defined(USE_DML)
threshold = 0.008f;
#endif
for (int i = 0; i < size; ++i) {
if (std::isnan(f_expected[i])) {