Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -795,9 +795,6 @@ typedef struct OrtCompileApi OrtCompileApi;
struct OrtEpApi;
typedef struct OrtEpApi OrtEpApi;

struct OrtNodeComputeInfo;
typedef struct OrtNodeComputeInfo OrtNodeComputeInfo;

/** \brief The helper interface to get the right version of OrtApi
*
* Get a pointer to this structure through ::OrtGetApiBase
Expand Down
48 changes: 42 additions & 6 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,44 @@ ORT_RUNTIME_CLASS(EpFactory);
ORT_RUNTIME_CLASS(EpGraphSupportInfo);
ORT_RUNTIME_CLASS(NodeComputeContext);

struct OrtNodeFusionOptions;
typedef struct OrtNodeFusionOptions OrtNodeFusionOptions;

struct OrtNodeComputeInfo;
typedef struct OrtNodeComputeInfo OrtNodeComputeInfo;

/**
* \brief The OrtNodeFusionOptions struct specifies options for fusing nodes supported by an execution provider.
*
* Refer to OrtEpApi::EpGraphSupportInfo_AddNodesToFuse.
*
* \since Version 1.23.
*/
struct OrtNodeFusionOptions {
/** \brief The ONNX Runtime version the OrtNodeFusionOptions was compiled with.
*
* Implementation should set to ORT_API_VERSION.
* ORT will use this to ensure it does not use members that were not available when the EP library was compiled.
*
* \since Version 1.23.
*/
uint32_t ort_version_supported;

/** \brief If set to true, specify that the execution provider does not require ONNX Runtime to provide constant
* initializers as inputs to the fused node during model inference. This is used when the execution
* provider saves a copy of constant initializers, and allows ONNX Runtime to release constant initializers that
* are not used by any execution provider.
*
* If not specified, defaults to false. That is, ONNX Runtime provides constant initializers as inputs to
* the fused node by default.
*
* \since Version 1.23.
*/
bool drop_constant_initializers;

// const OrtNode* fused_node_schema;
};

/**
* \brief The OrtNodeComputeInfo struct provides functions that an OrtEp implements to specify the compute
* function for a compiled OrtGraph instance.
Expand All @@ -21,7 +59,7 @@ struct OrtNodeComputeInfo {
/** \brief The ONNX Runtime version the OrtNodeComputeInfo was compiled with.
*
* Implementation should set to ORT_API_VERSION.
* ORT will use this to ensure it does not call functions that were not available when the library was compiled.
* ORT will use this to ensure it does not call functions that were not available when the EP library was compiled.
*
* \since Version 1.23.
*/
Expand Down Expand Up @@ -87,9 +125,6 @@ struct OrtEpApi {
ORT_CLASS_RELEASE(EpDevice);

/** \brief Specify nodes that are supported by an OrtEp and should be fused into one node.
*
* IMPORTANT: This is not the final version of this API function. This is currently experimental but will
* be stabilized by the ONNX Runtime 1.23 release.
*
* Because the nodes will be fused into one "fused node", there must not exist an unsupported node in
* a path between two of the provided nodes. Otherwise, the graph will become invalid.
Expand All @@ -100,14 +135,15 @@ struct OrtEpApi {
* \param[in] graph_support_info OrtEpGraphSupportInfo instance to which to add the supported nodes.
* \param[in] nodes Array of nodes supported by the EP that should be fused/compiled.
* \param[in] num_nodes The number of supported nodes.
* \param[in] node_fusion_options Optional node fusion options. Ignored if set to NULL.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* graph_support_info,
_In_reads_(num_nodes) const OrtNode* const* nodes, _In_ size_t num_nodes
/*, OrtFusedNodeSchema* optional_fused_node_schema, OrtNodesToOptimizeInfo* nodes_to_opt*/);
_In_reads_(num_nodes) const OrtNode* const* nodes, _In_ size_t num_nodes,
_In_opt_ const OrtNodeFusionOptions* node_fusion_options);

/** \brief Specify a node that is supported by an OrtEp and should be run with a registered EP kernel.
*
Expand Down
7 changes: 4 additions & 3 deletions onnxruntime/core/session/abi_ep_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
#include "core/graph/ep_api_types.h"
#include "core/session/abi_devices.h"

onnxruntime::Status OrtEpGraphSupportInfo::AddNodesToFuse(gsl::span<const OrtNode* const> nodes) {
onnxruntime::Status OrtEpGraphSupportInfo::AddNodesToFuse(gsl::span<const OrtNode* const> nodes,
const OrtNodeFusionOptions* optional_fusion_options) {
std::vector<const onnxruntime::EpNode*> ep_nodes;
ep_nodes.reserve(nodes.size());

Expand All @@ -20,14 +21,14 @@ onnxruntime::Status OrtEpGraphSupportInfo::AddNodesToFuse(gsl::span<const OrtNod
ep_nodes.push_back(ep_node);
}

node_groupings.emplace_back(NodeGroupingKind::kFusedNode, std::move(ep_nodes));
node_groupings.emplace_back(NodeGroupingKind::kFusedNode, std::move(ep_nodes),
optional_fusion_options != nullptr ? *optional_fusion_options : OrtNodeFusionOptions{});
return onnxruntime::Status::OK();
}

onnxruntime::Status OrtEpGraphSupportInfo::AddSingleNode(const OrtNode* node) {
std::vector<const onnxruntime::EpNode*> ep_nodes;
ep_nodes.push_back(onnxruntime::EpNode::ToInternal(node));
node_groupings.emplace_back(NodeGroupingKind::kSingleAssignedNode, std::move(ep_nodes));

return onnxruntime::Status::OK();
}
9 changes: 6 additions & 3 deletions onnxruntime/core/session/abi_ep_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,19 @@ struct OrtEpGraphSupportInfo {

// A grouping of supported nodes that should be handled in a single ComputeCapability.
struct NodeGrouping {
NodeGrouping(NodeGroupingKind kind, std::vector<const onnxruntime::EpNode*>&& nodes)
: kind(kind), nodes(std::move(nodes)) {}
NodeGrouping(NodeGroupingKind kind, std::vector<const onnxruntime::EpNode*>&& nodes,
const OrtNodeFusionOptions& fusion_options = {})
: kind(kind), nodes(std::move(nodes)), fusion_options(fusion_options) {}

NodeGroupingKind kind = NodeGroupingKind::kInvalidGrouping;
std::vector<const onnxruntime::EpNode*> nodes;
OrtNodeFusionOptions fusion_options = {};
};

explicit OrtEpGraphSupportInfo(const onnxruntime::EpGraph& graph) : ort_graph(graph) {}

onnxruntime::Status AddNodesToFuse(gsl::span<const OrtNode* const> nodes);
onnxruntime::Status AddNodesToFuse(gsl::span<const OrtNode* const> nodes,
const OrtNodeFusionOptions* node_fusion_options = nullptr);
onnxruntime::Status AddSingleNode(const OrtNode* node);

const onnxruntime::EpGraph& ort_graph;
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/core/session/ep_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ ORT_API(void, ReleaseEpDevice, _Frees_ptr_opt_ OrtEpDevice* device) {
}

ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* ort_graph_support_info,
_In_reads_(num_nodes) const OrtNode* const* nodes, size_t num_nodes) {
_In_reads_(num_nodes) const OrtNode* const* nodes, size_t num_nodes,
_In_opt_ const OrtNodeFusionOptions* node_fusion_options) {
API_IMPL_BEGIN
if (ort_graph_support_info == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid OrtGraph instance");
Expand All @@ -55,7 +56,7 @@ ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInf
}

gsl::span<const OrtNode* const> nodes_span(nodes, nodes + num_nodes);
ORT_API_RETURN_IF_STATUS_NOT_OK(ort_graph_support_info->AddNodesToFuse(nodes_span));
ORT_API_RETURN_IF_STATUS_NOT_OK(ort_graph_support_info->AddNodesToFuse(nodes_span, node_fusion_options));
return nullptr;
API_IMPL_END
}
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/session/ep_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ ORT_API_STATUS_IMPL(CreateEpDevice, _In_ OrtEpFactory* ep_factory,
ORT_API(void, ReleaseEpDevice, _Frees_ptr_opt_ OrtEpDevice* device);

ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* graph_support_info,
_In_reads_(num_nodes) const OrtNode* const* nodes, size_t num_nodes);
_In_reads_(num_nodes) const OrtNode* const* nodes, _In_ size_t num_nodes,
_In_opt_ const OrtNodeFusionOptions* node_fusion_options);
ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddSingleNode, _In_ OrtEpGraphSupportInfo* graph_support_info,
_In_ const OrtNode* node);
ORT_API(const char*, NodeComputeContext_NodeName, _In_ const OrtNodeComputeContext* context);
Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/session/ep_plugin_provider_interfaces.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
} else if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kFusedNode) {
std::unordered_set<const Node*> node_set;
node_set.reserve(node_grouping.nodes.size());

for (const EpNode* ep_node : node_grouping.nodes) {
node_set.insert(&ep_node->GetInternalNode());
}
Expand All @@ -151,7 +152,8 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
// unsupported nodes in any path between supported nodes.
std::vector<std::unique_ptr<ComputeCapability>> capabilities = utils::CreateSupportedPartitions(
graph_viewer, node_set, /*stop_ops*/ {}, PluginEpMetaDefNameFunctor(generator, graph_viewer, this->Type()),
this->Type(), this->Type(), /*node_unit_map*/ nullptr);
this->Type(), this->Type(), /*node_unit_map*/ nullptr,
node_grouping.fusion_options.drop_constant_initializers);

if (capabilities.size() > 1) {
LOGS_DEFAULT(ERROR) << "OrtEp::GetCapability() set nodes that cannot be fused together. "
Expand Down
Loading
Loading