4 changes: 3 additions & 1 deletion docs/OperatorKernels.md
@@ -449,7 +449,9 @@ Do not modify directly.*
|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Affine|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|And|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|7+|**T** = tensor(bool)<br/> **T1** = tensor(bool)|
|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|11+|**T** = tensor(double), tensor(float), tensor(float16)|
|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||12|**T** = tensor(double), tensor(float), tensor(float16)|
|||11|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|11+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
53 changes: 47 additions & 6 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -743,9 +743,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, TopK);

// opset 11
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, ArgMin);
@@ -898,6 +898,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax);

//OpSet 13
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow);
@@ -1112,6 +1115,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax);

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add);
@@ -1592,9 +1598,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, uint8_t, DequantizeLinear)>,

// opset 11
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, ArgMin)>,
@@ -1743,6 +1749,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax)>,

// OpSet 13
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow)>,
@@ -1957,6 +1966,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax)>,

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add)>,
@@ -2137,6 +2149,32 @@ static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node,
return false;
}

static bool ArgMaxNeedFallbackToCPU(const onnxruntime::Node& node) {
// Opset 12 introduced the attribute "select_last_index"
if (node.SinceVersion() >= 12) {
const auto& node_attributes = node.GetAttributes();

for (auto& attr : node_attributes) {
auto& attr_name = attr.first;
auto& attr_value = attr.second;

// CuDNN doesn't support picking the last index when duplicate max values
// are encountered.
// CuDNN's API doc doesn't state what happens when duplicates are encountered,
// but based on testing, the results seem to indicate a "stable" implementation
**Contributor:** Let's verify this with our Nvidia POC.
// (i.e., relative ordering is preserved), which is the expected behavior when the
// attribute takes on the default value (the most common use case for this operator).
if ("select_last_index" == attr_name) {
if (attr_value.i() != 0) {
return true;
}
}
}
}

return false;
}
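Illustrative aside (not part of the change): a minimal sketch of the two tie-breaking modes that opset 12's `select_last_index` attribute selects between. cuDNN's reduction appears to return the first occurrence, which is why only the default mode can stay on the CUDA EP.

```cpp
#include <cstdint>
#include <vector>

// Sketch of ONNX ArgMax tie-breaking over a 1-D input (illustration only).
// ">=" keeps the last duplicate (select_last_index == 1); ">" keeps the first.
int64_t ArgMax1D(const std::vector<float>& data, bool select_last_index) {
  int64_t best = 0;
  for (int64_t i = 1; i < static_cast<int64_t>(data.size()); ++i) {
    const bool better = select_last_index ? data[i] >= data[best]
                                          : data[i] > data[best];
    if (better) best = i;
  }
  return best;
}
```

For example, `{1.f, 3.f, 3.f, 2.f}` yields index 1 by default and index 2 with `select_last_index == 1`.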

static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node) {
const auto& node_attributes = node.GetAttributes();
// Check attributes
@@ -2257,6 +2295,9 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
} else if ("ConvTranspose" == node.OpType()) {
not_supported = ConvTransposeNeedFallbackToCPU(node);
force_inside = !not_supported;
} else if ("ArgMax" == node.OpType()) {
not_supported = ArgMaxNeedFallbackToCPU(node);
force_inside = !not_supported;
} else if ("Cast" == node.OpType()) {
not_supported = CastNeedFallbackToCPU(node);
// cast is not compute heavy, and may be placed outside
79 changes: 59 additions & 20 deletions onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
@@ -103,7 +103,6 @@ namespace cuda {
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);

// CUDA ArgMax/ArgMin doesn't have OpSet12 implementation (with select_last_index attr), keep it in OpSet11 for now.
#define REGISTER_KERNEL_TYPED_11(name, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
@@ -122,6 +121,40 @@ namespace cuda {
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);

#define REGISTER_ARGMAX_KERNEL_TYPED_13(T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
ArgMax, \
kOnnxDomain, \
1, 10, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
ArgMax<T>); \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
ArgMax, \
kOnnxDomain, \
11, 11, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
ArgMax<T>); \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
ArgMax, \
kOnnxDomain, \
12, 12, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
ArgMax<T>); \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
ArgMax, \
kOnnxDomain, \
13, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
ArgMax<T>);
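As a rough sketch of the dispatch these four registrations rely on (an assumed simplification, not the actual `KernelRegistry` code): a node binds to the registration whose opset range contains the node's since-version, so `[1, 10]`, `[11, 11]`, `[12, 12]`, and `13+` partition ArgMax's history without overlap.

```cpp
#include <climits>

// Simplified model of versioned-kernel matching (assumed, for illustration).
struct OpsetRange { int start; int end; };  // end == INT_MAX means "start and up"

bool Matches(const OpsetRange& r, int since_version) {
  return r.start <= since_version && since_version <= r.end;
}
// The ArgMax macro above registers {1, 10}, {11, 11}, {12, 12}, {13, INT_MAX}.
```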

// Register with the latest version 13
#define REGISTER_KERNEL_TYPED_13(name, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
@@ -808,24 +841,24 @@ Status ReduceKernel<allow_multi_axes>::ComputeImpl(OpKernelContext* ctx, cudnnRe
cudnnDataType_t cudnn_type_X = CUDNN_DATA_FLOAT; \
IAllocatorUniquePtr<float> temp_X = GetScratchBuffer<float>(input_count); \
Impl_Cast<CudaT, float>(Stream(), reinterpret_cast<const CudaT*>(X->template Data<T>()), temp_X.get(), X->Shape().Size()); \
\
ORT_RETURN_IF_ERROR(reduce_desc.Set(cudnn_reduce_op, cudnn_type_X, CUDNN_REDUCE_TENSOR_NO_INDICES)); \
ORT_RETURN_IF_ERROR(input_tensor.Set(input_dims_cudnn, cudnn_type_X)); \
ORT_RETURN_IF_ERROR(output_tensor.Set(output_dims_cudnn, cudnn_type_X)); \
CUDNN_RETURN_IF_ERROR( \
cudnnGetReductionIndicesSize(CudnnHandle(), reduce_desc, input_tensor, output_tensor, &indices_bytes)); \
CUDNN_RETURN_IF_ERROR( \
cudnnGetReductionWorkspaceSize(CudnnHandle(), reduce_desc, input_tensor, output_tensor, &workspace_bytes)); \
IAllocatorUniquePtr<uint32_t> indices_cuda = GetScratchBuffer<uint32_t>(indices_bytes); \
IAllocatorUniquePtr<CudaT> workspace_cuda = GetScratchBuffer<CudaT>(workspace_bytes); \
\
const auto one = Consts<float>::One; \
const auto zero = Consts<float>::Zero; \
auto temp_Y = GetScratchBuffer<float>(output_count); \
CUDNN_RETURN_IF_ERROR(cudnnReduceTensor(CudnnHandle(), reduce_desc, indices_cuda.get(), indices_bytes, \
workspace_cuda.get(), workspace_bytes, &one, input_tensor, temp_X.get(), \
&zero, output_tensor, temp_Y.get())); \
\
\
ORT_RETURN_IF_ERROR(reduce_desc.Set(cudnn_reduce_op, cudnn_type_X, CUDNN_REDUCE_TENSOR_NO_INDICES)); \
ORT_RETURN_IF_ERROR(input_tensor.Set(input_dims_cudnn, cudnn_type_X)); \
ORT_RETURN_IF_ERROR(output_tensor.Set(output_dims_cudnn, cudnn_type_X)); \
CUDNN_RETURN_IF_ERROR( \
cudnnGetReductionIndicesSize(CudnnHandle(), reduce_desc, input_tensor, output_tensor, &indices_bytes)); \
CUDNN_RETURN_IF_ERROR( \
cudnnGetReductionWorkspaceSize(CudnnHandle(), reduce_desc, input_tensor, output_tensor, &workspace_bytes)); \
IAllocatorUniquePtr<uint32_t> indices_cuda = GetScratchBuffer<uint32_t>(indices_bytes); \
IAllocatorUniquePtr<CudaT> workspace_cuda = GetScratchBuffer<CudaT>(workspace_bytes); \
\
const auto one = Consts<float>::One; \
const auto zero = Consts<float>::Zero; \
auto temp_Y = GetScratchBuffer<float>(output_count); \
CUDNN_RETURN_IF_ERROR(cudnnReduceTensor(CudnnHandle(), reduce_desc, indices_cuda.get(), indices_bytes, \
workspace_cuda.get(), workspace_bytes, &one, input_tensor, temp_X.get(), \
&zero, output_tensor, temp_Y.get())); \
\
Impl_Cast<float, CudaT>(Stream(), temp_Y.get(), reinterpret_cast<CudaT*>(Y->template MutableData<T>()), output_count); \
\
return Status::OK(); \
@@ -1015,8 +1048,14 @@ template std::unique_ptr<Tensor> ReduceCompute<MLFloat16, CUDNN_REDUCE_TENSOR_NO
REGISTER_KERNEL_TYPED_11(name, float) \
REGISTER_KERNEL_TYPED_11(name, double)

REGISTER_KERNEL_HFD_11(ArgMax)
REGISTER_KERNEL_HFD_11(ArgMin)

// If support for select_last_index == 1 is ever added, remove the fallback
// logic in ArgMaxNeedFallbackToCPU() in cuda_execution_provider.cc
REGISTER_ARGMAX_KERNEL_TYPED_13(MLFloat16)
REGISTER_ARGMAX_KERNEL_TYPED_13(float)
REGISTER_ARGMAX_KERNEL_TYPED_13(double)

REGISTER_KERNEL_HFD(ReduceL1)
REGISTER_KERNEL_HFD(ReduceL2)

10 changes: 9 additions & 1 deletion onnxruntime/core/providers/cuda/reduction/reduction_ops.h
@@ -88,7 +88,15 @@ class ReduceKernel : public CudaKernel, public ReduceKernelBase<allow_multi_axes
template <typename T>
class ArgMax final : public ReduceKernel<false> {
public:
ArgMax(const OpKernelInfo& info) : ReduceKernel<false>(info) {}
ArgMax(const OpKernelInfo& info) : ReduceKernel<false>(info) {
// The following is just a safety check.
// The logic in ArgMaxNeedFallbackToCPU() ensures that ArgMax nodes with
// select_last_index == 1 are never assigned to the CUDA EP.
int64_t select_last_index = 0;
if (info.GetAttr<int64_t>("select_last_index", &select_last_index).IsOK()) {
ORT_ENFORCE(select_last_index == 0, "select_last_index == 1 is not supported on CUDA");
}
}

Status ComputeInternal(OpKernelContext* ctx) const override {
return ComputeImpl<T, CUDNN_REDUCE_TENSOR_FLATTENED_INDICES>(ctx, CUDNN_REDUCE_TENSOR_MAX);
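As a closing usage-level note: a model whose ArgMax sets `select_last_index = 1` should still load and run with the CUDA EP enabled, because `GetCapability` declines the node and it falls back to the CPU EP; the `ORT_ENFORCE` above is only a backstop. A hypothetical sketch with the C++ API (model path and include layout assumed):

```cpp
#include <onnxruntime_cxx_api.h>
#include <cuda_provider_factory.h>  // declares OrtSessionOptionsAppendExecutionProvider_CUDA

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "argmax-fallback");
  Ort::SessionOptions opts;
  OrtSessionOptionsAppendExecutionProvider_CUDA(opts, /*device_id=*/0);
  // Hypothetical model containing ArgMax with select_last_index = 1:
  // that node is assigned to CPU while the rest of the graph can stay on CUDA.
  Ort::Session session(env, "argmax_select_last_index.onnx", opts);
  return 0;
}
```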