Skip to content

Commit 73c77c8

Browse files
zz002Zhenze Wangzhenzew
authored andcommitted
[VitisAI] Cache node subgraph when necessary (microsoft#22073)
### Description <!-- Describe your changes. --> [VitisAI] Cache node subgraph when necessary ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: Zhenze Wang <[email protected]> Co-authored-by: zhenzew <[email protected]>
1 parent 41cc60a commit 73c77c8

File tree

10 files changed

+32
-9
lines changed

10 files changed

+32
-9
lines changed

cmake/onnxruntime_providers_vitisai.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
file(GLOB onnxruntime_providers_vitisai_cc_srcs CONFIGURE_DEPENDS
1313
"${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.cc"
1414
"${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.h"
15+
"${ONNXRUNTIME_ROOT}/core/providers/vitisai/include/vaip/*.h"
1516
"${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.cc"
1617
"${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.h"
1718
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h"

onnxruntime/core/providers/shared_library/provider_interfaces.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,7 @@ struct ProviderHost {
996996
bool include_outer_scope_args,
997997
int execution_order) noexcept = 0;
998998
virtual const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const = 0;
999+
virtual IOnnxRuntimeOpSchemaCollectionPtr GraphViewer__GetSchemaRegistry(const GraphViewer* p) const = 0;
9991000

10001001
// OpKernel
10011002
virtual const Node& OpKernel__Node(const OpKernel* p) = 0;

onnxruntime/core/providers/shared_library/provider_wrappedtypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,6 +1068,7 @@ class GraphViewer final {
10681068
g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args, execution_order);
10691069
}
10701070
const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->GraphViewer__GetProducerNode(this, node_arg_name); }
1071+
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return g_host->GraphViewer__GetSchemaRegistry(this); }
10711072

10721073
GraphViewer() = delete;
10731074
GraphViewer(const GraphViewer&) = delete;

onnxruntime/core/providers/vitisai/imp/global_api.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,11 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
444444
}
445445
};
446446
the_global_api.node_arg_external_location = vaip::node_arg_external_location;
447+
the_global_api.model_to_proto = [](onnxruntime::Model& model) { return model.ToProto().release(); };
448+
the_global_api.model_proto_serialize_as_string = [](ONNX_NAMESPACE::ModelProto& model_proto) {
449+
return vaip_core::DllSafe(model_proto.SerializeAsString());
450+
};
451+
the_global_api.model_proto_delete = [](ONNX_NAMESPACE::ModelProto* p) { delete p; };
447452
if (!s_library_vitisaiep.vaip_get_version) {
448453
return reinterpret_cast<vaip_core::OrtApiForVaip*>(&(the_global_api.host_));
449454
} else {

onnxruntime/core/providers/vitisai/include/vaip/custom_op.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,18 @@ class ExecutionProvider {
2525
virtual DllSafe<std::vector<std::string>> get_meta_def_nodes() const = 0;
2626
virtual DllSafe<std::vector<std::string>>
2727
get_meta_def_constant_initializer() const = 0;
28+
virtual bool get_meta_def_fallback_CPU() const { return false; };
2829
virtual std::unique_ptr<CustomOp> compile() const = 0;
2930

3031
public:
31-
inline void set_fused_node(const onnxruntime::Node* fused_node) {
32-
fused_node_ = fused_node;
33-
}
34-
inline const onnxruntime::Node* get_fused_node() const {
35-
return fused_node_;
36-
}
32+
inline void set_fused_node(const onnxruntime::Node* fused_node) { fused_node_ = fused_node; }
33+
inline const onnxruntime::Node* get_fused_node() const { return fused_node_; }
34+
inline void set_model(onnxruntime::Model* model) { model_ = model; }
35+
inline onnxruntime::Model* get_model() const { return model_; }
3736

3837
private:
3938
const onnxruntime::Node* fused_node_ = nullptr;
39+
onnxruntime::Model* model_ = nullptr;
4040
};
4141

4242
class CustomOp {

onnxruntime/core/providers/vitisai/include/vaip/my_ort.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ struct NodeAttributes;
2020
namespace ONNX_NAMESPACE {
2121
struct AttributeProto;
2222
struct TensorProto;
23+
struct ModelProto;
2324
#ifndef USE_VITISAI
2425
enum TensorProto_DataType : int {
2526
TensorProto_DataType_UNDEFINED = 0,
@@ -70,6 +71,7 @@ enum AttributeProto_AttributeType : int {
7071
namespace vaip_core {
7172
class GraphHolder;
7273
using ONNX_NAMESPACE::AttributeProto;
74+
using ONNX_NAMESPACE::ModelProto;
7375
using ONNX_NAMESPACE::TensorProto;
7476
using onnxruntime::Graph;
7577
using onnxruntime::GraphViewer;

onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ struct OrtApi;
1313

1414
namespace vaip_core {
1515

16-
#define VAIP_ORT_API_MAJOR (10u)
16+
#define VAIP_ORT_API_MAJOR (11u)
1717
#define VAIP_ORT_API_MINOR (0u)
1818
#define VAIP_ORT_API_PATCH (0u)
1919
struct OrtApiForVaip {
@@ -231,6 +231,9 @@ struct OrtApiForVaip {
231231
gsl::span<const NodeArg* const> inputs); // [92]
232232
int (*node_arg_external_location)(const Graph& graph, const NodeArg& node_arg, std::string& file, size_t& offset, size_t& size, size_t& checksum); // [93]
233233
void (*session_option_configuration)(void* mmap, void* session_option, void (*push)(void* mmap, const char* name, const char* value)); // [94]
234+
ModelProto* (*model_to_proto)(Model& model); // [95]
235+
DllSafe<std::string> (*model_proto_serialize_as_string)(ModelProto& model_proto); // [96]
236+
void (*model_proto_delete)(ModelProto* p); // [97]
234237
};
235238

236239
#ifndef USE_VITISAI

onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,17 @@ common::Status VitisAIExecutionProvider::Compile(const std::vector<FusedNodeAndG
7676
auto& attrs = fused_node_graph.fused_node.get().GetAttributes();
7777
assert(attrs.count("index"));
7878
size_t index = attrs.at("index").i();
79-
(**this->execution_providers_)[index]->set_fused_node(&fused_node_graph.fused_node.get());
79+
auto& ep = (**this->execution_providers_)[index];
80+
ep->set_fused_node(&fused_node_graph.fused_node.get());
81+
if (ep->get_meta_def_fallback_CPU()) {
82+
auto& subgraph = fused_node_graph.filtered_graph.get();
83+
auto& logger = logging::LoggingManager::DefaultLogger();
84+
auto model_proto = subgraph.CreateModel(logger)->ToProto();
85+
subgraph.ToProto(*model_proto->mutable_graph(), true, true);
86+
auto local_registries = IOnnxRuntimeOpSchemaRegistryList{subgraph.GetSchemaRegistry()};
87+
auto model = Model::Create(std::move(*model_proto), subgraph.ModelPath(), &local_registries, logger);
88+
ep->set_model(model.release());
89+
}
8090
compute_info.create_state_func = [this, index](ComputeContext* context, FunctionState* state) {
8191
auto* p = (**this->execution_providers_)[index]->compile().release();
8292
*state = p;

onnxruntime/core/providers/vitisai/vitisai_execution_provider.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ class VitisAIExecutionProvider : public IExecutionProvider {
5050
ProviderOptions info_;
5151
std::vector<OrtCustomOpDomain*> custom_op_domains_;
5252
std::shared_ptr<KernelRegistry> registry_;
53-
std::set<std::string> vitisai_optypes_;
5453
// EP context related.
5554
bool ep_ctx_enabled_ = false;
5655
bool ep_ctx_embed_mode_ = true;

onnxruntime/core/session/provider_bridge_ort.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,6 +1212,7 @@ struct ProviderHostImpl : ProviderHost {
12121212
GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args, static_cast<ExecutionOrder>(execution_order));
12131213
}
12141214
const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); }
1215+
IOnnxRuntimeOpSchemaCollectionPtr GraphViewer__GetSchemaRegistry(const GraphViewer* p) const override { return p->GetSchemaRegistry(); }
12151216

12161217
// OpKernel (direct)
12171218
const Node& OpKernel__Node(const OpKernel* p) override { return p->OpKernel::Node(); }

0 commit comments

Comments
 (0)