Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -554,8 +554,12 @@ Do not modify directly.*
|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Affine|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|And|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|7+|**T** = tensor(bool)<br/> **T1** = tensor(bool)|
|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||12|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||12|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|AveragePool|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)|
|||10|**T** = tensor(double), tensor(float), tensor(float16)|
|||[7, 9]|**T** = tensor(double), tensor(float), tensor(float16)|
Expand Down
63 changes: 60 additions & 3 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,13 @@
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum);

class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin);

// OpSet 13
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add);
Expand Down Expand Up @@ -1199,6 +1206,13 @@
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin);

// OpSet 14
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, CumSum);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Relu);
Expand Down Expand Up @@ -1640,6 +1654,9 @@
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax)>,

Check warning on line 1657 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1657: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMax)>,

Check warning on line 1658 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1658: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMax)>,

Check warning on line 1659 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1659: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL1)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL1)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL1)>,
Expand Down Expand Up @@ -1822,9 +1839,6 @@
19, IsInf)>,

// opset 11
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Compress)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Concat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Flatten)>,
Expand Down Expand Up @@ -1916,6 +1930,13 @@
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax)>,

Check warning on line 1933 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1933: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax)>,

Check warning on line 1934 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1934: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax)>,

Check warning on line 1935 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1935: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin)>,

Check warning on line 1936 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1936: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin)>,

Check warning on line 1937 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1937: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin)>,

Check warning on line 1938 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1938: Lines should be <= 120 characters long [whitespace/line_length] [2]

// OpSet 13
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add)>,
Expand Down Expand Up @@ -2150,6 +2171,13 @@
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMax)>,

Check warning on line 2174 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:2174: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin)>,

// OpSet 14
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, CumSum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Relu)>,
Expand Down Expand Up @@ -2566,6 +2594,32 @@
return false;
}

static bool ArgMaxOrArgMinNeedFallbackToCPU(const onnxruntime::Node& node) {
// Opset 12 introduced the attribute "select_last_index"
if (node.SinceVersion() >= 12) {
const auto& node_attributes = node.GetAttributes();

for (auto& attr : node_attributes) {
auto& attr_name = attr.first;
auto& attr_value = attr.second;

// CuDNN doesn't support picking the last index in case of encountering
// duplicate max values.
// CuDNN's API doc doesn't mention what happens in case duplicates are encountered,
// but based on testing, the results seem to indicate a "stable" implementation
// (i.e.) relative ordering is preserved which is the expected behavior when the
// attribute takes on the default value (most common use-case for this operator).
if ("select_last_index" == attr_name) {
if (attr_value.i() != 0) {
return true;
}
}
}
}

return false;
}

std::unique_ptr<onnxruntime::IDataTransfer> CUDAExecutionProvider::GetDataTransfer() const {
return std::make_unique<onnxruntime::GPUDataTransfer>();
}
Expand Down Expand Up @@ -2615,6 +2669,9 @@
} else if ("ConvTranspose" == node.OpType()) {
not_supported = ConvTransposeNeedFallbackToCPU(node, logger, graph, IsNHWCPreferred());
force_inside = !not_supported;
} else if ("ArgMax" == node.OpType() || "ArgMin" == node.OpType()) {
not_supported = ArgMaxOrArgMinNeedFallbackToCPU(node);
force_inside = !not_supported;
} else if ("Cast" == node.OpType()) {
not_supported = CastNeedFallbackToCPU(node);
// cast is not compute heavy, and may be placed outside
Expand Down
28 changes: 16 additions & 12 deletions onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ using namespace onnxruntime::common;
namespace onnxruntime {
namespace cuda {

#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \
#define REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, begin, end) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
1, end, \
begin, end, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);

#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \
#define REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, version) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
Expand All @@ -37,8 +37,13 @@ namespace cuda {
name<T>);

#define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \
REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur)
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last) \
REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur)

#define REGISTER_KERNEL_ARGMIN_OR_ARGMAX(name, T) \
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11) \
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12) \
REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13)

// TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored.
template <bool allow_multi_axes>
Expand Down Expand Up @@ -829,14 +834,13 @@ template std::unique_ptr<Tensor> ReduceCompute<MLFloat16, CUDNN_REDUCE_TENSOR_NO

} // namespace ReductionOps

// CUDA ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, MLFloat16, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, float, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, double, 11)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, MLFloat16)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, float)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, double)

REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, MLFloat16, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, float, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, double, 11)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, MLFloat16)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, float)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, double)

REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, float, 17, 18)
Expand Down
20 changes: 18 additions & 2 deletions onnxruntime/core/providers/cuda/reduction/reduction_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,15 @@ class ReduceKernel : public CudaKernel, public ReduceKernelBase<allow_multi_axes
template <typename T>
class ArgMax final : public ReduceKernel<false> {
public:
ArgMax(const OpKernelInfo& info) : ReduceKernel<false>(info) {}
ArgMax(const OpKernelInfo& info) : ReduceKernel<false>(info) {
// The following is just a safety check.
// The logic in ArgMaxOrArgMinNeedFallbackToCPU() makes sure to not assign ArgMax
// nodes with select_last_index == 1 to the CUDA EP.
int64_t select_last_index = 0;
if (info.GetAttr<int64_t>("select_last_index", &select_last_index).IsOK()) {
ORT_ENFORCE(select_last_index == 0, "select_last_index as 1 is not supported on CUDA");
}
}

Status ComputeInternal(OpKernelContext* ctx) const override {
return ComputeImpl<T, CUDNN_REDUCE_TENSOR_FLATTENED_INDICES>(ctx, CUDNN_REDUCE_TENSOR_MAX);
Expand All @@ -98,7 +106,15 @@ class ArgMax final : public ReduceKernel<false> {
template <typename T>
class ArgMin final : public ReduceKernel<false> {
public:
ArgMin(const OpKernelInfo& info) : ReduceKernel<false>(info) {}
ArgMin(const OpKernelInfo& info) : ReduceKernel<false>(info) {
// The following is just a safety check.
// The logic in ArgMaxOrArgMinNeedFallbackToCPU() makes sure to not assign ArgMin
// nodes with select_last_index == 1 to the CUDA EP.
int64_t select_last_index = 0;
if (info.GetAttr<int64_t>("select_last_index", &select_last_index).IsOK()) {
ORT_ENFORCE(select_last_index == 0, "select_last_index as 1 is not supported on CUDA");
}
}

Status ComputeInternal(OpKernelContext* ctx) const override {
return ComputeImpl<T, CUDNN_REDUCE_TENSOR_FLATTENED_INDICES>(ctx, CUDNN_REDUCE_TENSOR_MIN);
Expand Down
28 changes: 16 additions & 12 deletions onnxruntime/core/providers/rocm/reduction/reduction_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ using namespace onnxruntime::common;
namespace onnxruntime {
namespace rocm {

#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \
#define REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, begin, end) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
1, end, \
begin, end, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);

#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \
#define REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, version) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
Expand All @@ -37,8 +37,13 @@ namespace rocm {
name<T>);

#define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \
REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur)
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last) \
REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur)

#define REGISTER_KERNEL_ARGMIN_OR_ARGMAX(name, T) \
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11) \
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12) \
REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13)

// TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored.
template <bool allow_multi_axes>
Expand Down Expand Up @@ -830,14 +835,13 @@ template std::unique_ptr<Tensor> ReduceCompute<MLFloat16, MIOPEN_REDUCE_TENSOR_N

} // namespace ReductionOps

// ROCM ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, MLFloat16, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, float, 11)
// REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, double, 11)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, MLFloat16)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, float)
// REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, double)

REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, MLFloat16, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, float, 11)
// REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, double, 11)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, MLFloat16)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, float)
// REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, double)

REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, float, 17, 18)
Expand Down
Loading
Loading