Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions src/frontends/onnx/frontend/src/utils/arg_min_max_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ ArgMinMaxFactory::ArgMinMaxFactory(const Node& node)
m_axis{node.get_attribute_value<std::int64_t>("axis", 0)},
m_select_last_index{node.get_attribute_value<std::int64_t>("select_last_index", 0)} {}

std::shared_ptr<ov::Node> ArgMinMaxFactory::make_arg_max() const {
ov::Output<ov::Node> ArgMinMaxFactory::make_arg_max() const {
return make_topk_subgraph(v11::TopK::Mode::MAX);
}

std::shared_ptr<ov::Node> ArgMinMaxFactory::make_arg_min() const {
ov::Output<ov::Node> ArgMinMaxFactory::make_arg_min() const {
return make_topk_subgraph(v11::TopK::Mode::MIN);
}

std::shared_ptr<ov::Node> ArgMinMaxFactory::make_topk_subgraph(v11::TopK::Mode mode) const {
ov::Output<ov::Node> ArgMinMaxFactory::make_topk_subgraph(v11::TopK::Mode mode) const {
const auto k_node = v0::Constant::create(ov::element::i64, ov::Shape{}, {1});

if (m_select_last_index == 1) {
Expand Down Expand Up @@ -70,40 +70,50 @@ std::shared_ptr<ov::Node> ArgMinMaxFactory::make_topk_subgraph(v11::TopK::Mode m
const auto axis_node = v0::Constant::create(ov::element::i64, ov::Shape{1}, {normalized_axis});
const auto reverse = std::make_shared<v1::Reverse>(m_input_node, axis_node, v1::Reverse::Mode::INDEX);

const auto topk = std::make_shared<v11::TopK>(reverse, k_node, normalized_axis, mode, v1::TopK::SortType::NONE);
const auto topk = std::make_shared<v11::TopK>(reverse,
k_node,
normalized_axis,
mode,
v1::TopK::SortType::SORT_VALUES,
element::i64,
true);

const auto data_shape = std::make_shared<v0::ShapeOf>(m_input_node);
const auto dims_on_axis =
std::make_shared<v1::Gather>(data_shape,
axis_node,
v0::Constant::create(ov::element::i64, ov::Shape{}, {0}));

const auto res_index =
std::make_shared<v1::Subtract>(dims_on_axis,
std::make_shared<v0::Convert>(topk->output(1), ov::element::i64));
const auto res_index = std::make_shared<v1::Subtract>(dims_on_axis, topk->output(1));
const auto result =
std::make_shared<v1::Subtract>(res_index, v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}));

if (m_keep_dims == 0) {
const auto axis_to_remove = v0::Constant::create(ov::element::u64, ov::Shape{}, {topk->get_axis()});

return std::make_shared<v0::Squeeze>(result, axis_to_remove);
return {std::make_shared<v0::Squeeze>(result, axis_to_remove)};
}

return result;
return {result};
}

const auto topk = std::make_shared<v11::TopK>(m_input_node, k_node, m_axis, mode, v11::TopK::SortType::NONE);
const auto topk = std::make_shared<v11::TopK>(m_input_node,
k_node,
m_axis,
mode,
v11::TopK::SortType::SORT_VALUES,
element::i64,
true);

const auto result = std::make_shared<v0::Convert>(topk->output(1), ov::element::i64);
ov::Output<ov::Node> result = topk->output(1);

if (m_keep_dims == 0) {
const auto axis_to_remove = v0::Constant::create(ov::element::u64, ov::Shape{}, {topk->get_axis()});

return std::make_shared<v0::Squeeze>(result, axis_to_remove);
return {std::make_shared<v0::Squeeze>(result, axis_to_remove)};
}

return result;
return {result};
}
} // namespace utils
} // namespace onnx
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ class ArgMinMaxFactory {

/// \brief Creates ArgMax ONNX operation.
/// \return Sub-graph representing ArgMax op.
std::shared_ptr<ov::Node> make_arg_max() const;
ov::Output<ov::Node> make_arg_max() const;

/// \brief Creates ArgMin ONNX operation.
/// \return Sub-graph representing ArgMin op.
std::shared_ptr<ov::Node> make_arg_min() const;
ov::Output<ov::Node> make_arg_min() const;

private:
std::shared_ptr<ov::Node> make_topk_subgraph(ov::op::v11::TopK::Mode mode) const;
ov::Output<ov::Node> make_topk_subgraph(ov::op::v11::TopK::Mode mode) const;

const std::int64_t m_keep_dims;
ov::Output<ov::Node> m_input_node;
Expand Down
12 changes: 6 additions & 6 deletions src/frontends/onnx/tests/onnx_import.in.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2744,16 +2744,16 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_argmax_int32) {
auto model = convert_model("argmax_int32.onnx");

auto test_case = ov::test::TestCase(model, s_device);
test_case.add_input<std::int32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
test_case.add_expected_output<std::int64_t>({1, 1, 1, 1, 1, 1});
test_case.add_input<std::int32_t>({3, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
test_case.add_expected_output<std::int64_t>({0, 1, 1, 1, 1, 1});
test_case.run();
}

OPENVINO_TEST(${BACKEND_NAME}, onnx_model_argmin_int32) {
auto model = convert_model("argmin_int32.onnx");

auto test_case = ov::test::TestCase(model, s_device);
test_case.add_input<std::int32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
test_case.add_input<std::int32_t>({2, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
test_case.add_expected_output<std::int64_t>({0, 0, 0, 0});
test_case.run();
}
Expand All @@ -2762,7 +2762,7 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_argmax_float) {
auto model = convert_model("argmax_float.onnx");

auto test_case = ov::test::TestCase(model, s_device);
test_case.add_input<float>({4.f, 0.1f, 2.f, 3.f, -3.f, 1.f, -0.9f, 0.f, 1.f, 2.f, 3.f, 0.f});
test_case.add_input<float>({4.f, 0.1f, 2.f, 4.f, -3.f, 1.f, -0.9f, 0.f, 1.f, 2.f, 3.f, 0.f});
test_case.add_expected_output<std::int64_t>({0, 3, 0});
test_case.run();
}
Expand All @@ -2771,8 +2771,8 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_argmin_float) {
auto model = convert_model("argmin_float.onnx");

auto test_case = ov::test::TestCase(model, s_device);
test_case.add_input<float>({4.f, 0.1f, 2.f, 3.f, -3.f, 1.f, -0.9f, 0.f, 1.f, 2.f, 3.f, 0.f});
test_case.add_expected_output<std::int64_t>({1, 1, 0, 2});
test_case.add_input<float>({0.1f, 0.1f, 2.f, 3.f, -3.f, 1.f, -0.9f, 0.f, 1.f, 2.f, 3.f, 0.f});
test_case.add_expected_output<std::int64_t>({0, 1, 0, 2});
test_case.run();
}

Expand Down
Loading