7
7
#include < gsl/gsl>
8
8
#include < memory>
9
9
#include < vector>
10
+ #include < fstream>
10
11
11
12
#include " core/common/common.h"
12
13
#include " core/framework/tensorprotoutils.h"
13
14
#include " core/framework/tensor_type_and_shape.h"
14
15
#include " core/framework/onnxruntime_typeinfo.h"
15
16
#include " core/session/onnxruntime_cxx_api.h"
17
+ #include " core/graph/ep_api_types.h"
18
+ #include " core/graph/graph_proto_serializer.h"
16
19
17
20
#define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL
18
21
#include " core/providers/utils/ort_graph_to_proto.h"
@@ -31,6 +34,7 @@ namespace test {
31
34
// forward-declaration for utility that uses public C APIs to check that an OrtGraph is equivalent
32
35
// to a graph represented by the internal ORT GraphViewer class.
33
36
static void CheckGraphCApi (const GraphViewer& graph_viewer, const OrtGraph& api_graph);
37
+ static void Check_Graph_GetSubgraph (const OrtGraph& api_graph);
34
38
35
39
//
36
40
// Tests
@@ -73,6 +77,16 @@ TEST(EpGraphTest, Check3LayerNestedSubgraph) {
73
77
CheckGraphCApi (test_graph->GetGraphViewer (), test_graph->GetOrtGraph ());
74
78
}
75
79
80
+ TEST (EpGraphTest, Check3LayerNestedSubgraphV2) {
81
+ // The overall structure of this model is similar to the one used in "Check3LayerNestedSubgraph" test.
82
+ // The model consists of a graph with subgraphs nested across three levels.
83
+ // In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer).
84
+ auto test_graph = TestGraph::Load (ORT_TSTR (" testdata/three_layer_nested_subgraph_v2.onnx" ));
85
+ ASSERT_NE (test_graph, nullptr ) << " Failed to load test model" ;
86
+
87
+ CheckGraphCApi (test_graph->GetGraphViewer (), test_graph->GetOrtGraph ());
88
+ }
89
+
76
90
static void RunMNISTModel (const ORTCHAR_T* model_path, std::vector<float >& output_data) {
77
91
auto memory_info = Ort::MemoryInfo::CreateCpu (OrtDeviceAllocator, OrtMemTypeCPU);
78
92
Ort::SessionOptions sess_options;
@@ -474,6 +488,48 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span<const
474
488
}
475
489
}
476
490
491
+ // Checks the Graph_GetSubgraph C API
492
+ static void Check_Graph_GetSubgraph (const OrtGraph& api_graph) {
493
+ const OrtApi& ort_api = Ort::GetApi ();
494
+
495
+ // Get all the nodes
496
+ size_t num_nodes = 0 ;
497
+ ASSERT_ORTSTATUS_OK (ort_api.Graph_GetNumNodes (&api_graph, &num_nodes));
498
+
499
+ std::vector<const OrtNode*> nodes (num_nodes);
500
+ ASSERT_ORTSTATUS_OK (ort_api.Graph_GetNodes (&api_graph, nodes.data (), nodes.size ()));
501
+
502
+ // Select a half of nodes to create a OrtGraph
503
+ size_t num_selected_nodes = std::max ((nodes.size () >> 1 ), (size_t )1 );
504
+ std::vector<const OrtNode*> selected_nodes (num_selected_nodes);
505
+
506
+ for (size_t i = 0 ; i < num_selected_nodes; i++) {
507
+ selected_nodes[i] = nodes[i];
508
+ }
509
+
510
+ OrtGraph* sub_graph;
511
+ ASSERT_ORTSTATUS_OK (ort_api.Graph_GetGraphView (&api_graph, selected_nodes.data (), selected_nodes.size (), &sub_graph));
512
+
513
+ // Convert OrtGraph/GraphViewer to ModelProto and dump it to disk.
514
+ // If the GraphViewer associated with the OrtGraph somehow is incorrect, GraphViewerToProto() will throw.
515
+ const GraphViewer& sub_graph_viewer = EpGraph::ToInternal (sub_graph)->GetGraphViewer ();
516
+ std::unique_ptr<Model> model = std::make_unique<Model>(sub_graph_viewer.Name (), true , sub_graph_viewer.GetGraph ().GetLogger ());
517
+ auto model_proto = std::make_unique<ONNX_NAMESPACE::ModelProto>(model->ToProto ());
518
+ GraphViewerToProto (sub_graph_viewer, *model_proto->mutable_graph (), true , true , static_cast <ExecutionOrder>(1 ));
519
+ model_proto->set_ir_version (ONNX_NAMESPACE::Version::IR_VERSION);
520
+
521
+ const char * graph_name = nullptr ;
522
+ ASSERT_ORTSTATUS_OK (ort_api.Graph_GetName (&api_graph, &graph_name));
523
+ std::string name = graph_name;
524
+ name += " _half.onnx" ;
525
+
526
+ // Dump the graph for debugging
527
+ // std::fstream dump(name, std::ios::out | std::ios::trunc | std::ios::binary);
528
+ // model_proto->SerializeToOstream(&dump);
529
+
530
+ ort_api.ReleaseGraph (sub_graph);
531
+ }
532
+
477
533
// Checks that the contents of the original GraphViewer matches the contents of the OrtGraph.
478
534
// Uses the public C APIs to traverse the OrtGraph.
479
535
static void CheckGraphCApi (const GraphViewer& graph_viewer, const OrtGraph& api_graph) {
@@ -682,6 +738,9 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_
682
738
}
683
739
}
684
740
}
741
+
742
+ // Check creating an OrtGraph from a subset of nodes in an OrtGraph
743
+ Check_Graph_GetSubgraph (api_graph);
685
744
}
686
745
687
746
} // namespace test
0 commit comments