Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 45 additions & 8 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ extern "C" {
ORT_RUNTIME_CLASS(Ep);
ORT_RUNTIME_CLASS(EpFactory);
ORT_RUNTIME_CLASS(EpGraphSupportInfo);
ORT_RUNTIME_CLASS(NodeFusionOptions);
ORT_RUNTIME_CLASS(NodeComputeContext);

/**
Expand Down Expand Up @@ -86,28 +87,64 @@ struct OrtEpApi {

ORT_CLASS_RELEASE(EpDevice);

/** \brief Specify nodes that are supported by an OrtEp and should be fused into one node.
*
* IMPORTANT: This is not the final version of this API function. This is currently experimental but will
* be stabilized by the ONNX Runtime 1.23 release.
/** \brief Create an OrtNodeFusionOptions instance for specifying options for fusing nodes supported by an
* execution provider.
*
* Because the nodes will be fused into one "fused node", there must not exist an unsupported node in
* a path between two of the provided nodes. Otherwise, the graph will become invalid.
*
* \param[in] nodes Array of nodes supported by the EP that should be fused.
* \param[in] num_nodes The number of nodes.
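* \param[out] out Output parameter set to the OrtNodeFusionOptions instance that is created.
* Must be released with ReleaseNodeFusionOptions unless ownership is transferred to
* ONNX Runtime via EpGraphSupportInfo_AddNodesToFuse.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.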
*/
ORT_API2_STATUS(CreateNodeFusionOptions, _In_reads_(num_nodes) const OrtNode* const* nodes, _In_ size_t num_nodes,
_Outptr_ OrtNodeFusionOptions** out);

ORT_CLASS_RELEASE(NodeFusionOptions);

/** \brief Specify that the execution provider does not require ONNX Runtime to provide constant
* initializers as inputs to the fused node during model inference. This is used when the execution
* provider saves a copy of constant initializers, and allows ONNX Runtime to release constant initializers that
* are not used by any execution provider.
*
* If not specified, defaults to false. That is, ONNX Runtime provides constant initializers as inputs to
* the fused node by default.
*
* \param[in] options The OrtNodeFusionOptions instance.
* \param[in] drop True to indicate that the execution provider does not need ORT to provide constant initializers
* for this fused node.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(NodeFusionOptions_DropConstantInitializers, _In_ OrtNodeFusionOptions* options,
_In_ bool drop);

// ORT_API2_STATUS(NodeFusionOptions_SetMetaDefinition, _In_ OrtNodeFusionOptions* options,
// _In_ const OrtNode* fused_node_meta_def);

// ORT_API2_STATUS(NodeFusionOptions_SetOptimizationInfo, _In_ OrtNodeFusionOptions* options,
// _In_ const OrtOptimizationInfo* optimization_info);

/** \brief Specify nodes that are supported by an OrtEp and should be fused into one node.
*
* This function can be called multiple times. A subsequent call to this function will force the next set of
* nodes to be fused into a different node.
*
* \param[in] graph_support_info OrtEpGraphSupportInfo instance to which to add the supported nodes.
* \param[in] nodes Array of nodes supported by the EP that should be fused/compiled.
* \param[in] num_nodes The number of supported nodes.
* \param[in] node_fusion_options OrtNodeFusionOptions instance that specifies the nodes to fuse and other relevant
* options. ONNX Runtime takes ownership of this OrtNodeFusionOptions instance, so
* caller should NOT release the instance.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* graph_support_info,
_In_reads_(num_nodes) const OrtNode* const* nodes, _In_ size_t num_nodes
/*, OrtFusedNodeSchema* optional_fused_node_schema, OrtNodesToOptimizeInfo* nodes_to_opt*/);
_Inout_ OrtNodeFusionOptions* node_fusion_options);

/** \brief Specify a node that is supported by an OrtEp and should be run with a registered EP kernel.
*
Expand Down
27 changes: 11 additions & 16 deletions onnxruntime/core/session/abi_ep_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,19 @@
#include "core/graph/ep_api_types.h"
#include "core/session/abi_devices.h"

onnxruntime::Status OrtEpGraphSupportInfo::AddNodesToFuse(gsl::span<const OrtNode* const> nodes) {
std::vector<const onnxruntime::EpNode*> ep_nodes;
ep_nodes.reserve(nodes.size());

for (const OrtNode* node : nodes) {
const auto* ep_node = onnxruntime::EpNode::ToInternal(node);
ORT_RETURN_IF(ep_node == nullptr, "Invalid OrtNode variant for use in OrtEpApi.");
ep_nodes.push_back(ep_node);
}
// Returns a pointer to the stored OrtNodeFusionOptions when this grouping holds the
// fused-nodes alternative of the variant; returns nullptr otherwise.
const OrtNodeFusionOptions* OrtEpGraphSupportInfo::NodeGrouping::TryGetNodeFusionOptions() const {
  if (!std::holds_alternative<OrtNodeFusionOptions>(variant_)) {
    return nullptr;
  }

  return &std::get<OrtNodeFusionOptions>(variant_);
}

node_groupings.emplace_back(NodeGroupingKind::kFusedNode, std::move(ep_nodes));
return onnxruntime::Status::OK();
// Returns the stored EpNode pointer when this grouping holds the single-node
// alternative of the variant; returns nullptr otherwise.
const onnxruntime::EpNode* OrtEpGraphSupportInfo::NodeGrouping::TryGetSingleNode() const {
  if (const auto* stored_node = std::get_if<const onnxruntime::EpNode*>(&variant_)) {
    return *stored_node;
  }

  return nullptr;
}

onnxruntime::Status OrtEpGraphSupportInfo::AddSingleNode(const OrtNode* node) {
std::vector<const onnxruntime::EpNode*> ep_nodes;
ep_nodes.push_back(onnxruntime::EpNode::ToInternal(node));
node_groupings.emplace_back(NodeGroupingKind::kSingleAssignedNode, std::move(ep_nodes));
void OrtEpGraphSupportInfo::AddNodesToFuse(const OrtNodeFusionOptions& node_fusion_options) {
node_groupings.emplace_back(node_fusion_options);
}

return onnxruntime::Status::OK();
// Appends a new node grouping that refers to a single node. Only the node's address is
// stored, not a copy of the node.
void OrtEpGraphSupportInfo::AddSingleNode(const onnxruntime::EpNode& node) {
  const onnxruntime::EpNode* single_node = &node;
  node_groupings.emplace_back(single_node);
}
27 changes: 15 additions & 12 deletions onnxruntime/core/session/abi_ep_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <gsl/gsl>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "core/common/inlined_containers_fwd.h"
Expand All @@ -17,30 +18,32 @@ struct EpGraph;
struct EpNode;
} // namespace onnxruntime

// Options describing one group of nodes that an execution provider wants fused into a
// single node.
struct OrtNodeFusionOptions {
  // The nodes to fuse. Stored as non-owning pointers.
  std::vector<const onnxruntime::EpNode*> nodes;

  // When true, the execution provider does not need ONNX Runtime to provide constant
  // initializers as inputs to the fused node. Defaults to false.
  bool drop_constant_initializers{false};
};

/// <summary>
/// Class used to specify the nodes an EP supports. An instance of this class is passed to OrtEp's
/// GetCapability() function. An OrtEp adds groups of supported nodes to the OrtEpGraphSupportInfo instance.
/// </summary>
struct OrtEpGraphSupportInfo {
enum class NodeGroupingKind {
kInvalidGrouping = 0,
kSingleAssignedNode,
kFusedNode,
};

// A grouping of supported nodes that should be handled in a single ComputeCapability.
struct NodeGrouping {
NodeGrouping(NodeGroupingKind kind, std::vector<const onnxruntime::EpNode*>&& nodes)
: kind(kind), nodes(std::move(nodes)) {}
explicit NodeGrouping(const onnxruntime::EpNode* single_node) : variant_(single_node) {}
explicit NodeGrouping(const OrtNodeFusionOptions& node_fusion_options) : variant_(node_fusion_options) {}

const OrtNodeFusionOptions* TryGetNodeFusionOptions() const;
const onnxruntime::EpNode* TryGetSingleNode() const;

NodeGroupingKind kind = NodeGroupingKind::kInvalidGrouping;
std::vector<const onnxruntime::EpNode*> nodes;
private:
std::variant<OrtNodeFusionOptions, const onnxruntime::EpNode*> variant_ = {};
};

explicit OrtEpGraphSupportInfo(const onnxruntime::EpGraph& graph) : ort_graph(graph) {}

onnxruntime::Status AddNodesToFuse(gsl::span<const OrtNode* const> nodes);
onnxruntime::Status AddSingleNode(const OrtNode* node);
void AddNodesToFuse(const OrtNodeFusionOptions& node_fusion_options);
void AddSingleNode(const onnxruntime::EpNode& node);

const onnxruntime::EpGraph& ort_graph;
std::vector<NodeGrouping> node_groupings;
Expand Down
68 changes: 59 additions & 9 deletions onnxruntime/core/session/ep_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,60 @@
delete device;
}

ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* ort_graph_support_info,
_In_reads_(num_nodes) const OrtNode* const* nodes, size_t num_nodes) {
ORT_API_STATUS_IMPL(CreateNodeFusionOptions, _In_reads_(num_nodes) const OrtNode* const* nodes, _In_ size_t num_nodes,
_Outptr_ OrtNodeFusionOptions** out) {
API_IMPL_BEGIN
if (ort_graph_support_info == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid OrtGraph instance");
if (nodes == nullptr || num_nodes == 0) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify one or more valid nodes.");
}

if (num_nodes == 0 || nodes == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid array of 1 or more supported nodes");
auto node_fusion_options = std::make_unique<OrtNodeFusionOptions>();
node_fusion_options->nodes.reserve(num_nodes);

for (size_t i = 0; i < num_nodes; ++i) {
const OrtNode* ort_node = nodes[i];

if (ort_node == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtNode instance is NULL.");
}

const auto* ep_node = onnxruntime::EpNode::ToInternal(ort_node);

if (ep_node == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
"Unexpected variant of OrtNode that is not compatible with OrtEpApi.");
}

node_fusion_options->nodes.push_back(ep_node);
}

gsl::span<const OrtNode* const> nodes_span(nodes, nodes + num_nodes);
ORT_API_RETURN_IF_STATUS_NOT_OK(ort_graph_support_info->AddNodesToFuse(nodes_span));
*out = node_fusion_options.release();
return nullptr;
API_IMPL_END
}

// Destroys an OrtNodeFusionOptions instance via `delete`.
// Passing NULL is harmless: `delete` on a null pointer has no effect.
ORT_API(void, ReleaseNodeFusionOptions, _Frees_ptr_opt_ OrtNodeFusionOptions* options) {
delete options;
}

// Records on the given OrtNodeFusionOptions whether constant initializers may be dropped
// for the fused node.
//
// options: the OrtNodeFusionOptions instance to update. Must not be NULL.
// drop:    value stored in OrtNodeFusionOptions::drop_constant_initializers.
//
// Returns nullptr on success, or an ORT_INVALID_ARGUMENT status if `options` is NULL.
ORT_API_STATUS_IMPL(NodeFusionOptions_DropConstantInitializers, _In_ OrtNodeFusionOptions* options,
                    _In_ bool drop) {
  API_IMPL_BEGIN
  // This is a C API boundary: report a NULL handle as an error status rather than
  // dereferencing it.
  if (options == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'options' argument is NULL.");
  }

  options->drop_constant_initializers = drop;
  return nullptr;
  API_IMPL_END
}

// Registers a group of nodes to fuse on the given OrtEpGraphSupportInfo.
//
// ort_graph_support_info: instance that receives the node grouping. Must not be NULL.
// node_fusion_options:    options naming the nodes to fuse. This function takes ownership
//                         of the instance and releases it on every return path, so the
//                         caller must not release it afterwards. Must not be NULL.
//
// Returns nullptr on success, or an ORT_INVALID_ARGUMENT status if either argument is NULL.
ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* ort_graph_support_info,
                    _Inout_ OrtNodeFusionOptions* node_fusion_options) {
  API_IMPL_BEGIN
  // Take ownership before any validation so the options are freed even on early error returns.
  // NOTE: std::unique_ptr requires this translation unit to include <memory>.
  std::unique_ptr<OrtNodeFusionOptions> owned_options(node_fusion_options);

  // Validate the receiver before it is dereferenced below.
  if (ort_graph_support_info == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'graph_support_info' argument is NULL.");
  }

  if (owned_options == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'node_fusion_options' argument is NULL.");
  }

  // The grouping stores its own copy; the original is destroyed when owned_options goes out of scope.
  ort_graph_support_info->AddNodesToFuse(*owned_options);
  return nullptr;
  API_IMPL_END
}
Expand All @@ -71,7 +112,13 @@
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a non-null OrtNode");
}

ORT_API_RETURN_IF_STATUS_NOT_OK(ort_graph_support_info->AddSingleNode(node));
const auto* ep_node = onnxruntime::EpNode::ToInternal(node);

if (ep_node == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Unexpected variant of OrtNode is not compatible with OrtEpApi.");
}

ort_graph_support_info->AddSingleNode(*ep_node);
return nullptr;
API_IMPL_END
}
Expand All @@ -93,6 +140,9 @@
&OrtExecutionProviderApi::ReleaseEpDevice,
// End of Version 22 - DO NOT MODIFY ABOVE

&OrtExecutionProviderApi::CreateNodeFusionOptions,
&OrtExecutionProviderApi::ReleaseNodeFusionOptions,
&OrtExecutionProviderApi::NodeFusionOptions_DropConstantInitializers,
&OrtExecutionProviderApi::EpGraphSupportInfo_AddNodesToFuse,
&OrtExecutionProviderApi::EpGraphSupportInfo_AddSingleNode,
&OrtExecutionProviderApi::NodeComputeContext_NodeName,
Expand Down
7 changes: 6 additions & 1 deletion onnxruntime/core/session/ep_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@ ORT_API_STATUS_IMPL(CreateEpDevice, _In_ OrtEpFactory* ep_factory,

ORT_API(void, ReleaseEpDevice, _Frees_ptr_opt_ OrtEpDevice* device);

ORT_API_STATUS_IMPL(CreateNodeFusionOptions, _In_reads_(num_nodes) const OrtNode* const* nodes, _In_ size_t num_nodes,
_Outptr_ OrtNodeFusionOptions** out);
ORT_API(void, ReleaseNodeFusionOptions, _Frees_ptr_opt_ OrtNodeFusionOptions* options);
ORT_API_STATUS_IMPL(NodeFusionOptions_DropConstantInitializers, _In_ OrtNodeFusionOptions* options,
_In_ bool drop);
ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* graph_support_info,
_In_reads_(num_nodes) const OrtNode* const* nodes, size_t num_nodes);
_Inout_ OrtNodeFusionOptions* node_fusion_options);
ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddSingleNode, _In_ OrtEpGraphSupportInfo* graph_support_info,
_In_ const OrtNode* node);
ORT_API(const char*, NodeComputeContext_NodeName, _In_ const OrtNodeComputeContext* context);
Expand Down
18 changes: 8 additions & 10 deletions onnxruntime/core/session/ep_plugin_provider_interfaces.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,17 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie

// Create ComputeCapability instances from OrtEpGraphSupportInfo::NodeGrouping instances.
for (const OrtEpGraphSupportInfo::NodeGrouping& node_grouping : api_graph_support_info.node_groupings) {
if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kSingleAssignedNode) {
if (const EpNode* single_node = node_grouping.TryGetSingleNode(); single_node != nullptr) {
auto indexed_sub_graph = std::make_unique<IndexedSubGraph>();

indexed_sub_graph->nodes.push_back(node_grouping.nodes[0]->GetInternalNode().Index());
indexed_sub_graph->nodes.push_back(single_node->GetInternalNode().Index());
result.push_back(std::make_unique<ComputeCapability>(std::move(indexed_sub_graph)));
} else if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kFusedNode) {
} else if (const OrtNodeFusionOptions* node_fusion_options = node_grouping.TryGetNodeFusionOptions();
node_fusion_options != nullptr) {
std::unordered_set<const Node*> node_set;
node_set.reserve(node_grouping.nodes.size());
for (const EpNode* ep_node : node_grouping.nodes) {
node_set.reserve(node_fusion_options->nodes.size());

for (const EpNode* ep_node : node_fusion_options->nodes) {
node_set.insert(&ep_node->GetInternalNode());
}

Expand All @@ -146,7 +148,7 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
// unsupported nodes in any path between supported nodes.
std::vector<std::unique_ptr<ComputeCapability>> capabilities = utils::CreateSupportedPartitions(
graph_viewer, node_set, /*stop_ops*/ {}, PluginEpMetaDefNameFunctor(generator, graph_viewer, this->Type()),
this->Type(), this->Type(), /*node_unit_map*/ nullptr);
this->Type(), this->Type(), /*node_unit_map*/ nullptr, node_fusion_options->drop_constant_initializers);

if (capabilities.size() > 1) {
LOGS_DEFAULT(ERROR) << "OrtEp::GetCapability() set nodes that cannot be fused together. "
Expand All @@ -167,10 +169,6 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
}));

result.push_back(std::move(capabilities[0]));
} else {
LOGS_DEFAULT(ERROR) << "PluginExecutionProvider::GetCapability() has invalid NodeGroupingKind: "
<< static_cast<int>(node_grouping.kind);
return {};
}
}

Expand Down
Loading
Loading