Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
103 commits
Select commit Hold shift + click to select a range
732f29a
Create IExecutionProvider wrapper for OrtEp
adrianlizarraga May 28, 2025
9c39966
Add OrtEp::GetCapability and stub implementation in dummy ep
adrianlizarraga May 28, 2025
db2cda3
Simplify GetCapability for OrtEp (no allocator). Make OrtGraph compat…
adrianlizarraga May 29, 2025
93900ba
Add API OrtGraph_GetNodes()
adrianlizarraga May 29, 2025
5676355
Derive from OrtNode and OrtGraph instead of using variant (reduce str…
adrianlizarraga May 29, 2025
353fdb8
Separate model-editor and ep-api graph types into different files
adrianlizarraga May 29, 2025
62eb132
Revert variable name change to reduce diff
adrianlizarraga May 29, 2025
38b5398
Add some comments
adrianlizarraga May 29, 2025
09cea7a
Add unit test that adds the example plugin ep to a session.
adrianlizarraga May 30, 2025
a659a80
Merge branch 'main' into adrianl/ep-abi
adrianlizarraga May 30, 2025
9f45c18
Add C APIs to get a node's name and op_type
adrianlizarraga May 30, 2025
7cac7d1
Update error message when user passes in the wrong kind of OrtNode/Or…
adrianlizarraga May 30, 2025
5b38f4b
Add ToInternal() methods to convert from OrtNode to its derived classes
adrianlizarraga May 30, 2025
8abe39e
clean up
adrianlizarraga May 30, 2025
9abc3ff
Rename AddSupportedNodes api func. Dont force eps to provide a subgra…
adrianlizarraga May 30, 2025
df0636d
Add OrtEp::Compile()
adrianlizarraga May 31, 2025
177828a
Fill out Compile() wrapper function more
adrianlizarraga May 31, 2025
0206de9
Fix annotation
adrianlizarraga May 31, 2025
a3ab415
Update NodeComputeInfo handling
adrianlizarraga May 31, 2025
59b5367
Merge branch 'main' into adrianl/ep-abi
adrianlizarraga May 31, 2025
77b3037
Finish making Mul work in example plugin EP
adrianlizarraga May 31, 2025
1b616d8
rename file
adrianlizarraga May 31, 2025
217aca1
Apply suggestions from code review
adrianlizarraga May 31, 2025
2adb618
Update onnxruntime/test/autoep/library/example_plugin_ep.cc
adrianlizarraga May 31, 2025
e144c80
hide reinterpret_cast with a ToInternal call
adrianlizarraga May 31, 2025
9031bd5
Use only one enum to differentiate between the ModelEditorApi and EpA…
adrianlizarraga Jun 2, 2025
37dd688
Add many more graph apis for traversing edges
adrianlizarraga Jun 2, 2025
ee354a4
Move graph apis to OrtApi struct
adrianlizarraga Jun 2, 2025
3a99ff2
ifdef out stuff for minimal build
adrianlizarraga Jun 2, 2025
d969d11
More ifdef for minimal builds
adrianlizarraga Jun 2, 2025
deb145d
Rename API to GetValueInfoUses to account for multiple OrtNode instan…
adrianlizarraga Jun 3, 2025
ac95834
Fix call to GetUses
adrianlizarraga Jun 3, 2025
dfc4f48
Make returning the producer's output index optional
adrianlizarraga Jun 3, 2025
c58ec9d
Return error status from GetValueInfoTypeInfo if type info is not valid
adrianlizarraga Jun 3, 2025
d7dc6db
Add a test executable for testing use of OrtGraph APIs
adrianlizarraga Jun 3, 2025
c1ada2f
Add nv lib to test; clean up
adrianlizarraga Jun 4, 2025
4443f2e
Cleanup: add includes, update some comments
adrianlizarraga Jun 4, 2025
c219e56
Merge main and fix conflicts
adrianlizarraga Jun 4, 2025
2cf5e59
Add fused nodes back into OrtEp::Compile()
adrianlizarraga Jun 4, 2025
66f78e5
Remove 'Get' from APIs that look like field/property getters
adrianlizarraga Jun 4, 2025
8641558
Add remaining OrtGraph apis to get inputs/outputs
adrianlizarraga Jun 4, 2025
b34b83c
Set missing optional node outputs to a null OrtValueInfo
adrianlizarraga Jun 4, 2025
ef49e54
Rename some GetValueInfo* apis. Add ability to specify a single suppo…
adrianlizarraga Jun 5, 2025
a55988d
Fix GetValueConsumers() to return a negative input index for consumer…
adrianlizarraga Jun 5, 2025
66044f2
Add Node_GetSinceVersion C API
adrianlizarraga Jun 5, 2025
05166d2
Add C APIs: Node_GetNumSubgraphs, Node_GetSubgraphs, Node_GetParentGraph
adrianlizarraga Jun 6, 2025
55e43e0
initialize var to fix warning/error on linux. set input index to -1 f…
adrianlizarraga Jun 6, 2025
c1e63cd
Remove OrtHardwareDevice passing from GetCapability
adrianlizarraga Jun 6, 2025
bc445f9
Add C APIs: Node_GetNumImplicitInputs(), Node_GetImplicitInputs()
adrianlizarraga Jun 6, 2025
dd1efdb
minor clean up on comments
adrianlizarraga Jun 6, 2025
89e50eb
more efficient node_index_to_ep_node map
adrianlizarraga Jun 6, 2025
4282be6
Merge branch 'main' into adrianl/ep-abi
adrianlizarraga Jun 6, 2025
e755b84
Add C API: Graph_GetParentNode(). Split up test_ep_graph.cc into anot…
adrianlizarraga Jun 6, 2025
28eec29
Add C APIs to query whether an OrtValueInfo is a graph input, a graph…
adrianlizarraga Jun 9, 2025
4ca1f37
Merge branch 'main' into adrianl/ep-abi
adrianlizarraga Jun 9, 2025
4348e73
Address review comments and update graph ir tests
adrianlizarraga Jun 9, 2025
a96fc71
Merge branch 'main' into adrianl/ep-abi
adrianlizarraga Jun 9, 2025
e2a5b8e
Address more review comments
adrianlizarraga Jun 9, 2025
a9d8956
Add C API Node_Id() and sample Kahns topological sort that uses only …
adrianlizarraga Jun 9, 2025
e03ec4f
Merge branch 'main' into adrianl/ep-abi
adrianlizarraga Jun 12, 2025
0d36505
Split up files
adrianlizarraga Jun 12, 2025
c6aaf40
Address review comments: factory creation in loop, delete OrtStatus
adrianlizarraga Jun 12, 2025
9d0e2ae
Delete OrtStatus in provider_policy_context.cc
adrianlizarraga Jun 12, 2025
3927442
Remove intermediate variable
adrianlizarraga Jun 12, 2025
410e7df
Add C APIs to get initializer OrtValues
adrianlizarraga Jun 13, 2025
edb3b2d
Allow getting OrtValue for initializer defined in an outer scope
adrianlizarraga Jun 13, 2025
35a31a0
Remove 'order' parameter from Graph_GetNodes()
adrianlizarraga Jun 13, 2025
f7f4043
Merge main and fix conflicts
adrianlizarraga Jun 13, 2025
4772ee8
Fix warnings as errors
adrianlizarraga Jun 13, 2025
be4e86c
Add new c api header (for eps) to cmakes list of headers
adrianlizarraga Jun 13, 2025
3f77e03
Graph_GetNodes() directly return internal order (already sorted); Mak…
adrianlizarraga Jun 13, 2025
5c4af97
Python unit test that runs simple model with EP plugin
adrianlizarraga Jun 14, 2025
0011333
Update include/onnxruntime/core/session/onnxruntime_c_ep_api.h
adrianlizarraga Jun 14, 2025
ea41d12
Update include/onnxruntime/core/session/onnxruntime_c_ep_api.h
adrianlizarraga Jun 14, 2025
52d0332
Use GraphViewer's parent node/subgraph information to get initializer…
adrianlizarraga Jun 15, 2025
d10335e
Address many review comments
adrianlizarraga Jun 16, 2025
a446ca0
Merge branch 'adrianl/ep-abi' of github.com:microsoft/onnxruntime int…
adrianlizarraga Jun 16, 2025
f078b93
merge main and fix conflicts
adrianlizarraga Jun 16, 2025
9a993c5
Apply suggestions from code review
adrianlizarraga Jun 16, 2025
079add6
Add documentation for OrtConstPointerArray API functions
adrianlizarraga Jun 16, 2025
32de043
Clean up
adrianlizarraga Jun 16, 2025
ff85323
Apply suggestions from code review
adrianlizarraga Jun 16, 2025
93c6769
Add unit test for a graph with 3 layers of nested subgraphs
adrianlizarraga Jun 17, 2025
cd785b9
Start stubbing out IExecutionProvider::GetEpContextNodes()
adrianlizarraga Jun 17, 2025
e072a2d
Improve OrtArrayOfConstObjects to allow api user to create
adrianlizarraga Jun 18, 2025
c40baba
Clean up
adrianlizarraga Jun 18, 2025
74d3b4e
Add ArrayOfConstObjects_GetConstData
adrianlizarraga Jun 18, 2025
bae384d
Apply suggestions from code review
adrianlizarraga Jun 18, 2025
981726f
Adjust naming and description of OrtArrayOfConstObjects
adrianlizarraga Jun 18, 2025
910cd5e
wording: object vs element
adrianlizarraga Jun 18, 2025
ba3f356
Merge branch 'adrianl/ep-abi' into adrianl/ep-abi-ep-context-nodes
adrianlizarraga Jun 18, 2025
7950af3
Add code to Example Plugin EP that creates EPContext nodes and return…
adrianlizarraga Jun 19, 2025
8449070
Merge main and fix conflicts
adrianlizarraga Jun 19, 2025
b4f9f97
Encounter block with current approach
adrianlizarraga Jun 20, 2025
d29285b
Get the plugin EP to generate an EPContext model and added unit test
adrianlizarraga Jun 20, 2025
4649264
Update comments
adrianlizarraga Jun 20, 2025
459c4b5
Add missing documentation comment to the public API function
adrianlizarraga Jun 20, 2025
0aac40a
Clarify that EPContext nodes are constant
adrianlizarraga Jun 20, 2025
aa26dbd
Merge branch 'main' into adrianl/ep-abi-ep-context-nodes
adrianlizarraga Jun 20, 2025
cbd0a10
More editing of the public documentation
adrianlizarraga Jun 20, 2025
af1a312
Example EP code was getting too long. Separate utility functions into…
adrianlizarraga Jun 20, 2025
e3dff3e
Version 2 of GetEpContextNodes
adrianlizarraga Jun 20, 2025
9a127e7
Use correct SAL2 tag
adrianlizarraga Jun 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1828,6 +1828,8 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND
NOT onnxruntime_MINIMAL_BUILD)
onnxruntime_add_shared_library_module(example_plugin_ep
${TEST_SRC_DIR}/autoep/library/example_plugin_ep_utils.h
${TEST_SRC_DIR}/autoep/library/example_plugin_ep_utils.cc
${TEST_SRC_DIR}/autoep/library/example_plugin_ep.cc)
target_include_directories(example_plugin_ep PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session)
target_link_libraries(example_plugin_ep PRIVATE onnxruntime)
Expand Down
2 changes: 1 addition & 1 deletion include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -3678,7 +3678,7 @@ struct OrtApi {
*
* \param[in] name Name of the attribute
* \param[in] data Data content of the attribute
* \param[in] len Number of bytes stored in data
* \param[in] len Number of elements if data represents an array (e.g., ORT_OP_ATTR_INTS). Otherwise, set to 1.
* \param[in] type Data type
* \param[out] op_attr Attribute that has been created, which must be released by OrtApi::ReleaseOpAttr
*
Expand Down
23 changes: 19 additions & 4 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,14 @@ struct OrtEp {
/** \brief Compile OrtGraph instances assigned to the OrtEp. Implementer must set a OrtNodeComputeInfo instance
* for each OrtGraph in order to define its computation function.
*
* If the session is configured to generate a pre-compiled model, the execution provider must return EPContext nodes,
* as OrtNode instances, that ONNX Runtime uses to create a pre-compiled model, known as an "EPContext model".
* An EPContext model contains EPContext nodes. Each EPContext node encapsulates the pre-compiled binary data for a
* OrtGraph compiled for a specific execution provider. For more details about the EPContext design, refer to:
* \htmlonly
* <a href="https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html">EPContext design document.</a>
* \endhtmlonly
*
* \param[in] this_ptr The OrtEp instance.
* \param[in] graphs Array of `count` OrtGraph instances to compile. Each graph contains only the nodes for
* which the execution provider indicated support. Nested subgraphs contained by a
Expand All @@ -190,9 +198,15 @@ struct OrtEp {
* Each fused node is an OrtNode initialized with the intended fused node name and
* input/output information.
* \param[in] count The number of OrtGraph instances to compile.
* \param[inout] node_compute_infos Array of `count` OrtNodeComputeInfo instances that define each OrtGraph instance's
* computation function. The implementer allocates the OrtNodeComputeInfo instances.
* ORT calls ReleaseNodeComputeInfos() to release multiple instances in a batch.
* \param[out] node_compute_infos Array of `count` OrtNodeComputeInfo instances that define each OrtGraph instance's
* computation function. The implementer allocates the OrtNodeComputeInfo instances.
* ORT calls ReleaseNodeComputeInfos() to release multiple instances in a batch.
* \param[out] ep_context_nodes Output array of `count` OrtNode instances, each representing an EPContext
* node for a compiled OrtGraph. The execution provider must use
* OrtModelEditorApi::CreateNode to create the OrtNode instances. ONNX Runtime takes
* ownership of the OrtNode instances, so the execution provider must NOT call
* OrtApi::ReleaseNode. Should be ignored if the session is not configured to generate an
* EPContext model.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
Expand All @@ -204,7 +218,8 @@ struct OrtEp {
*/
OrtStatus*(ORT_API_CALL* Compile)(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs,
_In_ const OrtNode** fused_nodes, _In_ size_t count,
_Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos);
_Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos,
_Out_writes_(count) OrtNode** ep_context_nodes);

/** \brief Release OrtNodeComputeInfo instances.
*
Expand Down
130 changes: 125 additions & 5 deletions onnxruntime/core/session/ep_plugin_provider_interfaces.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,18 @@
#include <unordered_set>
#include <utility>
#include <vector>
#include "core/framework/abi_pointer_array.h"
#include "core/framework/compute_capability.h"
#include "core/framework/error_code_helper.h"
#include "core/framework/model_metadef_id_generator.h"
#include "core/graph/ep_api_types.h"
#include "core/session/ort_apis.h"
#include "core/graph/model_editor_api_types.h"
#include "core/session/abi_devices.h"
#include "core/session/abi_ep_types.h"
#include "core/session/abi_logger.h"
#include "core/session/abi_session_options_impl.h"
#include "core/session/allocator_adapters.h"
#include "core/session/ort_apis.h"
#include "core/providers/partitioning_utils.h"

namespace onnxruntime {
Expand Down Expand Up @@ -48,7 +51,8 @@ PluginExecutionProviderFactory::CreateProvider(const OrtSessionOptions& session_
ORT_THROW("Error creating execution provider: ", status.ToString());
}

auto ep_wrapper = std::make_unique<PluginExecutionProvider>(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory_)));
auto ep_wrapper = std::make_unique<PluginExecutionProvider>(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory_)),
session_options);
ep_wrapper->SetLogger(session_logger.ToInternal());

return ep_wrapper;
Expand Down Expand Up @@ -80,9 +84,10 @@ struct PluginEpMetaDefNameFunctor {
// PluginExecutionProvider
//

PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep)
PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options)
: IExecutionProvider(ep->GetName(ep.get()), OrtDevice()), // TODO: What to do about OrtDevice for plugins?
ort_ep_(std::move(ep)) {
generate_ep_ctx_model_ = session_options.value.GetEpContextGenerationOptions().enable;
}

PluginExecutionProvider::~PluginExecutionProvider() {
Expand Down Expand Up @@ -185,6 +190,87 @@ Status PluginExecutionProvider::FusedNodeState::AddFusedNode(const Node& fused_n
return Status::OK();
}

/// <summary>
/// Converts the EPContext nodes provided by the plugin EP (OrtNode instances) to onnxruntime::Node instances.
/// Note that the EP plugin uses the model editor API to create the OrtNode instances.
/// </summary>
/// <param name="ep_name">Name of the plugin EP.</param>
/// <param name="plugin_ep_context_nodes">EPContext nodes provided by the plugin EP.</param>
/// <param name="result_nodes">Output parameter set to the resulting array of EPContext nodes.</param>
/// <param name="result_node_args">Output parameter that stores the NodeArgs used by the EPContext nodes.</param>
/// <returns>A status indicating success or an error.</returns>
static Status ConvertEpContextNodes(const std::string& ep_name, const std::vector<OrtNode*> plugin_ep_context_nodes,
/*out*/ std::vector<std::unique_ptr<Node>>& result_nodes,
/*out*/ std::vector<std::unique_ptr<NodeArg>>& result_node_args) {
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
if (plugin_ep_context_nodes.empty()) {
return Status::OK(); // No EPContext nodes.
}

std::vector<std::unique_ptr<Node>> ep_context_nodes_holder;
std::vector<std::unique_ptr<NodeArg>> ep_context_node_args_holder;

ep_context_nodes_holder.reserve(plugin_ep_context_nodes.size());

for (const OrtNode* ort_node : plugin_ep_context_nodes) {
ORT_RETURN_IF_NOT(ort_node != nullptr, ep_name, ": OrtEp::Compile() returned a NULL EPContext node.");

const ModelEditorNode* editor_node = ModelEditorNode::ToInternal(ort_node);
ORT_RETURN_IF_NOT(editor_node != nullptr, ep_name, ": OrtEp::Compile() returned OrtNode objects ",
"that were not created with OrtModelEditorApi.");

// Create NodeArg for each input/output.
std::vector<NodeArg*> input_node_args;
std::vector<NodeArg*> output_node_args;

input_node_args.reserve(editor_node->input_names.size());
output_node_args.reserve(editor_node->output_names.size());

for (const std::string& input_name : editor_node->input_names) {
auto node_arg = std::make_unique<NodeArg>(input_name, /*p_arg_type*/ nullptr); // Graph.Resolve() sets type.
input_node_args.push_back(node_arg.get());
ep_context_node_args_holder.push_back(std::move(node_arg));
}

for (const std::string& output_name : editor_node->output_names) {
auto node_arg = std::make_unique<NodeArg>(output_name, /*p_arg_type*/ nullptr); // Graph.Resolve() sets type.
output_node_args.push_back(node_arg.get());
ep_context_node_args_holder.push_back(std::move(node_arg));
}

// Create a name -> attribute map.
NodeAttributes attributes;
attributes.reserve(editor_node->attributes.size());

for (const ONNX_NAMESPACE::AttributeProto& attr : editor_node->attributes) {
attributes.emplace(attr.name(), attr);
}

// Create Node
auto internal_node = std::make_unique<Node>(editor_node->node_name,
editor_node->operator_name,
"EPContext node for " + ep_name,
input_node_args,
output_node_args,
&attributes,
editor_node->domain_name);

ep_context_nodes_holder.push_back(std::move(internal_node));
}

result_nodes = std::move(ep_context_nodes_holder);
result_node_args = std::move(ep_context_node_args_holder);

return Status::OK();
#else
ORT_UNUSED_PARAMETER(ep_name);
ORT_UNUSED_PARAMETER(plugin_ep_context_nodes);
ORT_UNUSED_PARAMETER(result_nodes);
ORT_UNUSED_PARAMETER(result_node_args);
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Creating EPContext models is not supported in this build");
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
}

common::Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_infos) {
const logging::Logger* logger = GetLogger();
Expand Down Expand Up @@ -220,8 +306,21 @@ common::Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGr
api_fused_nodes.push_back(ep_fused_node->ToExternal());
}

ORT_RETURN_IF_ERROR(ToStatusAndRelease(ort_ep_->Compile(ort_ep_.get(), api_graphs.data(), api_fused_nodes.data(),
num_graphs, api_node_compute_infos.data())));
// Provide an output buffer for the plugin EP to store EPContext nodes if it needs to (i.e., enabled in session options).
std::vector<std::unique_ptr<OrtNode, decltype(&OrtApis::ReleaseNode)>> plugin_ep_context_nodes_holder;
std::vector<OrtNode*> plugin_ep_context_nodes;
plugin_ep_context_nodes_holder.reserve(num_graphs);
plugin_ep_context_nodes.resize(num_graphs, nullptr);

Status compile_status = ToStatusAndRelease(ort_ep_->Compile(ort_ep_.get(), api_graphs.data(), api_fused_nodes.data(),
num_graphs, api_node_compute_infos.data(),
plugin_ep_context_nodes.data()));

// Store any EPContext nodes provided by the plugin EP in std::unique_ptr so that they are always properly released.
for (OrtNode* ort_node : plugin_ep_context_nodes) {
auto unique_ort_node = std::unique_ptr<OrtNode, decltype(&OrtApis::ReleaseNode)>(ort_node, OrtApis::ReleaseNode);
plugin_ep_context_nodes_holder.push_back(std::move(unique_ort_node));
}

// Save OrtNodeComputeInfo created by OrtEp instance. They're freed when this IExecutionProvider
// is destroyed.
Expand All @@ -231,6 +330,8 @@ common::Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGr
}
}

ORT_RETURN_IF_ERROR(compile_status);

// Initialize node_compute_infos as wrappers to api_node_compute_infos.
for (size_t i = 0; i < num_graphs; i++) {
OrtNodeComputeInfo* api_node_compute_info = api_node_compute_infos[i];
Expand Down Expand Up @@ -268,6 +369,25 @@ common::Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGr
node_compute_infos.push_back(std::move(compute_info));
}

// Convert the EPContext nodes provided by the plugin EP into onnxruntime::Node instances.
// We store the converted Node and NodeArg instances as members to ensure they can be returned to the ORT graph
// partitioner via a call to IExecutionProvider::GetEpContextNodes().
if (generate_ep_ctx_model_) {
ORT_RETURN_IF_ERROR(ConvertEpContextNodes(Type(), plugin_ep_context_nodes,
/*out*/ ep_context_nodes_, /*out*/ ep_context_node_args_));
}

return Status::OK();
}

const InlinedVector<const Node*> PluginExecutionProvider::GetEpContextNodes() const {
InlinedVector<const Node*> result;

for (const std::unique_ptr<Node>& node : ep_context_nodes_) {
result.push_back(node.get());
}

return result;
}

} // namespace onnxruntime
12 changes: 11 additions & 1 deletion onnxruntime/core/session/ep_plugin_provider_interfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
namespace onnxruntime {
struct EpNode;
struct EpValueInfo;
class NodeArg;

/// <summary>
/// IExecutionProviderFactory that wraps a OrtEpFactory. Required for SessionOptionsAppendExecutionProvider_V2.
Expand Down Expand Up @@ -59,7 +60,7 @@ using UniqueOrtEp = std::unique_ptr<OrtEp, OrtEpDeleter>;
/// </summary>
class PluginExecutionProvider : public IExecutionProvider {
public:
explicit PluginExecutionProvider(UniqueOrtEp ep);
explicit PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options);
~PluginExecutionProvider();

std::vector<std::unique_ptr<ComputeCapability>>
Expand All @@ -71,6 +72,8 @@ class PluginExecutionProvider : public IExecutionProvider {
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) override;

const InlinedVector<const Node*> GetEpContextNodes() const override;

private:
struct FusedNodeState {
FusedNodeState() = default;
Expand All @@ -83,12 +86,19 @@ class PluginExecutionProvider : public IExecutionProvider {
};

UniqueOrtEp ort_ep_;
bool generate_ep_ctx_model_ = false;
std::vector<OrtNodeComputeInfo*> api_node_compute_infos_;

// Fused nodes have to be valid throughout model inference because they may be cached in NodeComputeInfo instances.
// For each fused node, the Compile() function creates EpNode and EpValueInfo instances on the heap,
// which are then passed to the underlying OrtEp instance. This class stores this "fused node state"
// so that it is not destroyed until the EP itself is destroyed.
std::vector<FusedNodeState> fused_node_states_;

// Stores the EPContext Nodes created from the OrtNode instances returned by the underlying plugin EP.
// Need to store both the Node and NodeArg instances so that they are available when the GraphPartitioner
// calls IExecutionProvider::GetEpContextNodes().
std::vector<std::unique_ptr<Node>> ep_context_nodes_;
std::vector<std::unique_ptr<NodeArg>> ep_context_node_args_;
};
} // namespace onnxruntime
2 changes: 1 addition & 1 deletion onnxruntime/core/session/provider_policy_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ Status ProviderPolicyContext::CreateExecutionProvider(const Environment& env, Or
ORT_RETURN_IF_ERROR(ToStatusAndRelease(info.ep_factory->CreateEp(info.ep_factory, info.devices.data(),
info.ep_metadata.data(), info.devices.size(),
&options, &logger, &api_ep)));
ep = std::make_unique<PluginExecutionProvider>(UniqueOrtEp(api_ep, OrtEpDeleter(*info.ep_factory)));
ep = std::make_unique<PluginExecutionProvider>(UniqueOrtEp(api_ep, OrtEpDeleter(*info.ep_factory)), options);
}

return Status::OK();
Expand Down
Loading
Loading