Skip to content

Commit 6a05ba6

Browse files
chilo-msjnagi-intel
authored andcommitted
[EP ABI] Add Graph_GetGraphView API to get a OrtGraph from a subset of nodes (microsoft#25191)
Added an API that creates a sub-graph from a set of nodes in an OrtGraph. This API is needed in the GetCapability EP ABI porting when EP wants to check whether a 'sub-graph' of the graph is supported by the hardware backend.
1 parent ed287d6 commit 6a05ba6

File tree

10 files changed

+237
-3
lines changed

10 files changed

+237
-3
lines changed

include/onnxruntime/core/graph/graph.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -952,9 +952,12 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
952952
return const_cast<Graph*>(this)->GetNodeArg(name);
953953
}
954954

955-
// search this and up through any parent_graph_ instance for a NodeArg
955+
// Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding mutable NodeArg
956956
NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name);
957957

958+
// Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding const NodeArg
959+
const NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const;
960+
958961
/** Gets a mutable NodeArg by name. Creates a new NodeArg that is owned by this Graph if not found.
959962
@param name The NodeArg name.
960963
@param[in] p_arg_type Optional TypeProto to use if the NodeArg needs to be created.

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5748,6 +5748,24 @@ struct OrtApi {
57485748
*/
57495749
ORT_API2_STATUS(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node);
57505750

5751+
/** \brief Returns an OrtGraph that contains a subset of nodes in the source OrtGraph.
5752+
*
5753+
* Note:
5754+
* The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference
5755+
* the same underlying graph.
5756+
*
5757+
* \param[in] src_graph The source OrtGraph instance.
5758+
* \param[in] nodes A subset of the nodes/OrtNodes in 'graph'.
5759+
* \param[in] num_nodes Number of nodes.
5760+
* \param[out] dst_sub_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph.
5761+
*
5762+
* \snippet{doc} snippets.dox OrtStatus Return Value
5763+
*
5764+
* \since Version 1.23.
5765+
*/
5766+
ORT_API2_STATUS(Graph_GetGraphView, _In_ const OrtGraph* src_graph, _In_ const OrtNode** nodes,
5767+
_In_ size_t num_nodes, _Outptr_ OrtGraph** dst_graph);
5768+
57515769
/// @}
57525770

57535771
/// \name OrtNode

onnxruntime/core/graph/ep_api_types.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,10 +505,34 @@ void EpGraph::IndexToEpNodeMap::SetEpNode(NodeIndex node_index, EpNode* ep_node)
505505
EpGraph::EpGraph(const GraphViewer& graph_viewer, PrivateTag)
506506
: OrtGraph(OrtGraphIrApi::kEpApi), graph_viewer_(graph_viewer) {}
507507

508+
EpGraph::EpGraph(std::unique_ptr<GraphViewer> graph_viewer,
509+
std::unique_ptr<IndexedSubGraph> indexed_sub_graph,
510+
PrivateTag)
511+
: OrtGraph(OrtGraphIrApi::kEpApi),
512+
graph_viewer_(*graph_viewer.get()),
513+
owned_graph_viewer_(std::move(graph_viewer)),
514+
owned_indexed_sub_graph_(std::move(indexed_sub_graph)) {}
515+
508516
// Static class function to create a std::unique_ptr<EpGraph>.
509517
Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result) {
510518
auto ep_graph = std::make_unique<EpGraph>(graph_viewer, PrivateTag{});
511519

520+
return CreateImpl(std::move(ep_graph), graph_viewer, result);
521+
}
522+
523+
// Static class function to create a std::unique_ptr<EpGraph>.
524+
Status EpGraph::Create(std::unique_ptr<GraphViewer> src_graph_viewer,
525+
std::unique_ptr<IndexedSubGraph> src_indexed_sub_graph,
526+
/*out*/ std::unique_ptr<EpGraph>& result) {
527+
auto& graph_viewer = *src_graph_viewer.get();
528+
auto ep_graph = std::make_unique<EpGraph>(std::move(src_graph_viewer),
529+
std::move(src_indexed_sub_graph),
530+
PrivateTag{});
531+
532+
return CreateImpl(std::move(ep_graph), graph_viewer, result);
533+
}
534+
535+
Status EpGraph::CreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result) {
512536
AllocatorPtr initializer_allocator = CPUAllocator::DefaultInstance();
513537
std::unordered_map<std::string, std::unique_ptr<EpValueInfo>> value_infos_map;
514538

onnxruntime/core/graph/ep_api_types.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,15 +251,32 @@ struct EpGraph : public OrtGraph {
251251

252252
public:
253253
EpGraph(const GraphViewer& graph_viewer, PrivateTag);
254+
EpGraph(std::unique_ptr<GraphViewer> graph_viewer,
255+
std::unique_ptr<IndexedSubGraph> indexed_sub_graph,
256+
PrivateTag);
254257

255258
/// <summary>
256259
/// Creates an instance of EpGraph, which wraps a GraphViewer.
260+
/// This call is used when creating an EpGraph from a GraphViewer instance. The GraphViewer instance is not onwed by this EpGraph.
257261
/// </summary>
258262
/// <param name="graph_viewer"></param>
259263
/// <param name="result"></param>
260264
/// <returns></returns>
261265
static Status Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result);
262266

267+
/// <summary>
268+
/// Creates an instance of EpGraph, which wraps a GraphViewer.
269+
/// This call is used when creating an EpGraph from a subset of nodes in another EpGraph.
270+
/// In this case, due to the implementation of OrtApis::Graph_GetGraphView, the new EpGraph instance
271+
/// must take ownership of both the GraphViewer and IndexedSubGraph.
272+
/// </summary>
273+
/// <param name="graph_viewer"></param>
274+
/// <param name="result"></param>
275+
/// <returns></returns>
276+
static Status Create(std::unique_ptr<GraphViewer> graph_viewer,
277+
std::unique_ptr<IndexedSubGraph> indexed_sub_graph,
278+
/*out*/ std::unique_ptr<EpGraph>& result);
279+
263280
// Defines ToExternal() and ToInternal() functions to convert between OrtGraph and EpGraph.
264281
DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(OrtGraph, EpGraph, OrtGraphIrApi::kEpApi)
265282

@@ -331,9 +348,22 @@ struct EpGraph : public OrtGraph {
331348
const OrtValue* GetInitializerValue(std::string_view name) const;
332349

333350
private:
351+
/// <summary>
352+
/// The real implementation of creating an EpGraph instance.
353+
/// Please use one of the above 'Create' functions that internally call this function, and avoid calling this function directly.
354+
/// </summary>
355+
/// <param name="ep_graph"></param>
356+
/// <param name="graph_viewer"></param>
357+
/// <param name="result"></param>
358+
/// <returns></returns>
359+
static Status CreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result);
360+
334361
const GraphViewer& graph_viewer_;
335362
const EpNode* parent_node_ = nullptr;
336363

364+
std::unique_ptr<GraphViewer> owned_graph_viewer_ = nullptr;
365+
std::unique_ptr<IndexedSubGraph> owned_indexed_sub_graph_ = nullptr;
366+
337367
std::vector<std::unique_ptr<EpNode>> nodes_;
338368
IndexToEpNodeMap index_to_ep_node_;
339369

onnxruntime/core/graph/graph.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1818,6 +1818,10 @@ NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name
18181818
return node_arg;
18191819
}
18201820

1821+
const NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const {
1822+
return const_cast<Graph*>(this)->GetNodeArgIncludingParentGraphs(node_arg_name);
1823+
}
1824+
18211825
void Graph::ReverseDFSFrom(gsl::span<NodeIndex const> from,
18221826
const std::function<void(const Node*)>& enter,
18231827
const std::function<void(const Node*)>& leave,

onnxruntime/core/graph/graph_viewer.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,15 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
168168
filtered_node_inputs_including_initializers_.reserve(metadef->inputs.size());
169169

170170
for (const auto& input : metadef->inputs) {
171-
const auto* nodearg = graph.GetNodeArg(input);
171+
// NodeArgs from the current scope or any outer scopes should be handled correctly.
172+
//
173+
// There is an edge case where the model consists of a graph with subgraphs nested across three levels.
174+
// In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer).
175+
// When constructing a new GraphViewer for the second- and third-layer subgraphs,
176+
// the second-layer graph may not have the corresponding value_info for that first-layer input,
177+
// because the second-layer graph itself doesn't consume it.
178+
// Therefore, when working within the second-layer graph, we need to search outer scopes for the missing value_info.
179+
const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(input);
172180
ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. Input not found:", input);
173181
filtered_node_inputs_including_initializers_.push_back(nodearg);
174182
if (!graph.IsInitializedTensor(input)) {
@@ -177,7 +185,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
177185
}
178186

179187
for (const auto& output : metadef->outputs) {
180-
const auto* nodearg = graph.GetNodeArg(output);
188+
const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(output);
181189
ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. Output not found:", output);
182190
filtered_node_outputs_.push_back(nodearg);
183191
}

onnxruntime/core/session/onnxruntime_c_api.cc

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2714,6 +2714,91 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetParentNode, _In_ const OrtGraph* graph, _O
27142714
API_IMPL_END
27152715
}
27162716

2717+
ORT_API_STATUS_IMPL(OrtApis::Graph_GetGraphView, _In_ const OrtGraph* src_graph,
2718+
_In_ const OrtNode** nodes,
2719+
_In_ size_t num_nodes,
2720+
_Outptr_ OrtGraph** dst_graph) {
2721+
API_IMPL_BEGIN
2722+
2723+
if (num_nodes == 0) {
2724+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_nodes' argument should be > 0");
2725+
}
2726+
2727+
const EpGraph* ep_graph = EpGraph::ToInternal(src_graph);
2728+
if (ep_graph == nullptr) {
2729+
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "src_graph is a ModelEditorGraph which doesn't support Graph_GetSubGraph.");
2730+
}
2731+
const Graph& graph = ep_graph->GetGraphViewer().GetGraph();
2732+
2733+
// Create a GraphViewer with filtered info
2734+
std::unique_ptr<IndexedSubGraph> indexed_sub_graph = std::make_unique<IndexedSubGraph>();
2735+
std::unique_ptr<IndexedSubGraph::MetaDef> metadef = std::make_unique<IndexedSubGraph::MetaDef>();
2736+
metadef->name = "sub_graph";
2737+
metadef->since_version = 1;
2738+
std::unordered_set<std::string> outputs;
2739+
std::unordered_set<const NodeArg*> initializers;
2740+
2741+
auto add_inputs = [&](ConstPointerContainer<std::vector<NodeArg*>> defs) {
2742+
for (const auto* def : defs) {
2743+
if (def->Exists()) {
2744+
// not the output of a previous node
2745+
if (outputs.count(def->Name()) == 0) {
2746+
metadef->inputs.push_back(def->Name());
2747+
} else {
2748+
// consumed by node so no longer subgraph output
2749+
// NOTE: Ignoring edge case where a node output is an overall graph output AND a node input
2750+
outputs.erase(def->Name());
2751+
}
2752+
2753+
if (graph.IsInitializedTensor(def->Name())) {
2754+
initializers.insert(def);
2755+
}
2756+
}
2757+
}
2758+
};
2759+
2760+
auto add_node = [&](const Node& node) {
2761+
indexed_sub_graph->nodes.push_back(node.Index());
2762+
add_inputs(node.InputDefs());
2763+
add_inputs(node.ImplicitInputDefs());
2764+
2765+
for (const auto* def : node.OutputDefs()) {
2766+
outputs.insert(def->Name());
2767+
}
2768+
};
2769+
2770+
// Add nodes
2771+
for (size_t node_idx = 0; node_idx < num_nodes; node_idx++) {
2772+
const OrtNode* ort_node = nodes[node_idx];
2773+
const EpNode* ep_node = EpNode::ToInternal(ort_node);
2774+
if (ep_node == nullptr) {
2775+
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Graph_GetSubGraph.");
2776+
}
2777+
add_node(ep_node->GetInternalNode());
2778+
}
2779+
2780+
// Add initializers
2781+
for (auto& initializer : initializers) {
2782+
metadef->constant_initializers.push_back(initializer->Name());
2783+
}
2784+
2785+
// Add outputs
2786+
for (auto& output : outputs) {
2787+
metadef->outputs.push_back(output);
2788+
}
2789+
2790+
indexed_sub_graph->SetMetaDef(std::move(metadef));
2791+
auto graph_viewer = std::make_unique<GraphViewer>(graph, *indexed_sub_graph.get());
2792+
2793+
std::unique_ptr<EpGraph> result;
2794+
ORT_API_RETURN_IF_STATUS_NOT_OK(EpGraph::Create(std::move(graph_viewer), std::move(indexed_sub_graph), result));
2795+
2796+
*dst_graph = result.release();
2797+
2798+
return nullptr;
2799+
API_IMPL_END
2800+
}
2801+
27172802
//
27182803
// OrtNode
27192804
//
@@ -3629,6 +3714,7 @@ static constexpr OrtApi ort_api_1_to_23 = {
36293714
&OrtApis::Graph_GetNumNodes,
36303715
&OrtApis::Graph_GetNodes,
36313716
&OrtApis::Graph_GetParentNode,
3717+
&OrtApis::Graph_GetGraphView,
36323718
&OrtApis::Node_GetId,
36333719
&OrtApis::Node_GetName,
36343720
&OrtApis::Node_GetOperatorType,

onnxruntime/core/session/ort_apis.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,8 @@ ORT_API_STATUS_IMPL(Graph_GetNumNodes, _In_ const OrtGraph* graph, _Out_ size_t*
649649
ORT_API_STATUS_IMPL(Graph_GetNodes, const OrtGraph* graph,
650650
_Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes);
651651
ORT_API_STATUS_IMPL(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node);
652+
ORT_API_STATUS_IMPL(Graph_GetGraphView, _In_ const OrtGraph* graph, _In_ const OrtNode** nodes, _In_ size_t num_nodes,
653+
_Outptr_ OrtGraph** subgraph);
652654

653655
// OrtNode
654656
ORT_API_STATUS_IMPL(Node_GetId, _In_ const OrtNode* node, _Out_ size_t* node_id);

onnxruntime/test/ep_graph/test_ep_graph.cc

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77
#include <gsl/gsl>
88
#include <memory>
99
#include <vector>
10+
#include <fstream>
1011

1112
#include "core/common/common.h"
1213
#include "core/framework/tensorprotoutils.h"
1314
#include "core/framework/tensor_type_and_shape.h"
1415
#include "core/framework/onnxruntime_typeinfo.h"
1516
#include "core/session/onnxruntime_cxx_api.h"
17+
#include "core/graph/ep_api_types.h"
18+
#include "core/graph/graph_proto_serializer.h"
1619

1720
#define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL
1821
#include "core/providers/utils/ort_graph_to_proto.h"
@@ -31,6 +34,7 @@ namespace test {
3134
// forward-declaration for utility that uses public C APIs to check that an OrtGraph is equivalent
3235
// to a graph represented by the internal ORT GraphViewer class.
3336
static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph);
37+
static void Check_Graph_GetSubgraph(const OrtGraph& api_graph);
3438

3539
//
3640
// Tests
@@ -73,6 +77,16 @@ TEST(EpGraphTest, Check3LayerNestedSubgraph) {
7377
CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph());
7478
}
7579

80+
TEST(EpGraphTest, Check3LayerNestedSubgraphV2) {
81+
// The overall structure of this model is similar to the one used in "Check3LayerNestedSubgraph" test.
82+
// The model consists of a graph with subgraphs nested across three levels.
83+
// In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer).
84+
auto test_graph = TestGraph::Load(ORT_TSTR("testdata/three_layer_nested_subgraph_v2.onnx"));
85+
ASSERT_NE(test_graph, nullptr) << "Failed to load test model";
86+
87+
CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph());
88+
}
89+
7690
static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector<float>& output_data) {
7791
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
7892
Ort::SessionOptions sess_options;
@@ -474,6 +488,48 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span<const
474488
}
475489
}
476490

491+
// Checks the Graph_GetSubgraph C API
492+
static void Check_Graph_GetSubgraph(const OrtGraph& api_graph) {
493+
const OrtApi& ort_api = Ort::GetApi();
494+
495+
// Get all the nodes
496+
size_t num_nodes = 0;
497+
ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumNodes(&api_graph, &num_nodes));
498+
499+
std::vector<const OrtNode*> nodes(num_nodes);
500+
ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, nodes.data(), nodes.size()));
501+
502+
// Select a half of nodes to create a OrtGraph
503+
size_t num_selected_nodes = std::max((nodes.size() >> 1), (size_t)1);
504+
std::vector<const OrtNode*> selected_nodes(num_selected_nodes);
505+
506+
for (size_t i = 0; i < num_selected_nodes; i++) {
507+
selected_nodes[i] = nodes[i];
508+
}
509+
510+
OrtGraph* sub_graph;
511+
ASSERT_ORTSTATUS_OK(ort_api.Graph_GetGraphView(&api_graph, selected_nodes.data(), selected_nodes.size(), &sub_graph));
512+
513+
// Convert OrtGraph/GraphViewer to ModelProto and dump it to disk.
514+
// If the GraphViewer associated with the OrtGraph somehow is incorrect, GraphViewerToProto() will throw.
515+
const GraphViewer& sub_graph_viewer = EpGraph::ToInternal(sub_graph)->GetGraphViewer();
516+
std::unique_ptr<Model> model = std::make_unique<Model>(sub_graph_viewer.Name(), true, sub_graph_viewer.GetGraph().GetLogger());
517+
auto model_proto = std::make_unique<ONNX_NAMESPACE::ModelProto>(model->ToProto());
518+
GraphViewerToProto(sub_graph_viewer, *model_proto->mutable_graph(), true, true, static_cast<ExecutionOrder>(1));
519+
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
520+
521+
const char* graph_name = nullptr;
522+
ASSERT_ORTSTATUS_OK(ort_api.Graph_GetName(&api_graph, &graph_name));
523+
std::string name = graph_name;
524+
name += "_half.onnx";
525+
526+
// Dump the graph for debugging
527+
// std::fstream dump(name, std::ios::out | std::ios::trunc | std::ios::binary);
528+
// model_proto->SerializeToOstream(&dump);
529+
530+
ort_api.ReleaseGraph(sub_graph);
531+
}
532+
477533
// Checks that the contents of the original GraphViewer matches the contents of the OrtGraph.
478534
// Uses the public C APIs to traverse the OrtGraph.
479535
static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph) {
@@ -682,6 +738,9 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_
682738
}
683739
}
684740
}
741+
742+
// Check creating an OrtGraph from a subset of nodes in an OrtGraph
743+
Check_Graph_GetSubgraph(api_graph);
685744
}
686745

687746
} // namespace test
1.85 KB
Binary file not shown.

0 commit comments

Comments
 (0)