diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index c860e0794abed..8861d14cfa512 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -795,9 +795,6 @@ typedef struct OrtCompileApi OrtCompileApi; struct OrtEpApi; typedef struct OrtEpApi OrtEpApi; -struct OrtNodeComputeInfo; -typedef struct OrtNodeComputeInfo OrtNodeComputeInfo; - /** \brief The helper interface to get the right version of OrtApi * * Get a pointer to this structure through ::OrtGetApiBase diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 68b6992177b0d..b3dac9f891cc1 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -12,6 +12,44 @@ ORT_RUNTIME_CLASS(EpFactory); ORT_RUNTIME_CLASS(EpGraphSupportInfo); ORT_RUNTIME_CLASS(NodeComputeContext); +struct OrtNodeFusionOptions; +typedef struct OrtNodeFusionOptions OrtNodeFusionOptions; + +struct OrtNodeComputeInfo; +typedef struct OrtNodeComputeInfo OrtNodeComputeInfo; + +/** + * \brief The OrtNodeFusionOptions struct specifies options for fusing nodes supported by an execution provider. + * + * Refer to OrtEpApi::EpGraphSupportInfo_AddNodesToFuse. + * + * \since Version 1.23. + */ +struct OrtNodeFusionOptions { + /** \brief The ONNX Runtime version the OrtNodeFusionOptions was compiled with. + * + * Implementation should set to ORT_API_VERSION. + * ORT will use this to ensure it does not use members that were not available when the EP library was compiled. + * + * \since Version 1.23. + */ + uint32_t ort_version_supported; + + /** \brief If set to true, specify that the execution provider does not require ONNX Runtime to provide constant + * initializers as inputs to the fused node during model inference. 
This is used when the execution + * provider saves a copy of constant initializers, and allows ONNX Runtime to release constant initializers that + * are not used by any execution provider. + * + * If not specified, defaults to false. That is, ONNX Runtime provides constant initializers as inputs to + * the fused node by default. + * + * \since Version 1.23. + */ + bool drop_constant_initializers; + + // const OrtNode* fused_node_schema; +}; + /** * \brief The OrtNodeComputeInfo struct provides functions that an OrtEp implements to specify the compute * function for a compiled OrtGraph instance. @@ -21,7 +59,7 @@ struct OrtNodeComputeInfo { /** \brief The ONNX Runtime version the OrtNodeComputeInfo was compiled with. * * Implementation should set to ORT_API_VERSION. - * ORT will use this to ensure it does not call functions that were not available when the library was compiled. + * ORT will use this to ensure it does not call functions that were not available when the EP library was compiled. * * \since Version 1.23. */ @@ -87,9 +125,6 @@ struct OrtEpApi { ORT_CLASS_RELEASE(EpDevice); /** \brief Specify nodes that are supported by an OrtEp and should be fused into one node. - * - * IMPORTANT: This is not the final version of this API function. This is currently experimental but will - * be stabilized by the ONNX Runtime 1.23 release. * * Because the nodes will be fused into one "fused node", there must not exist an unsupported node in * a path between two of the provided nodes. Otherwise, the graph will become invalid. @@ -100,14 +135,15 @@ struct OrtEpApi { * \param[in] graph_support_info OrtEpGraphSupportInfo instance to which to add the supported nodes. * \param[in] nodes Array of nodes supported by the EP that should be fused/compiled. * \param[in] num_nodes The number of supported nodes. + * \param[in] node_fusion_options Optional node fusion options. Ignored if set to NULL. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. 
*/ ORT_API2_STATUS(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* graph_support_info, - _In_reads_(num_nodes) const OrtNode* const* nodes, _In_ size_t num_nodes - /*, OrtFusedNodeSchema* optional_fused_node_schema, OrtNodesToOptimizeInfo* nodes_to_opt*/); + _In_reads_(num_nodes) const OrtNode* const* nodes, _In_ size_t num_nodes, + _In_opt_ const OrtNodeFusionOptions* node_fusion_options); /** \brief Specify a node that is supported by an OrtEp and should be run with a registered EP kernel. * diff --git a/onnxruntime/core/session/abi_ep_types.cc b/onnxruntime/core/session/abi_ep_types.cc index 719f55b4e6b38..14764251898aa 100644 --- a/onnxruntime/core/session/abi_ep_types.cc +++ b/onnxruntime/core/session/abi_ep_types.cc @@ -10,7 +10,8 @@ #include "core/graph/ep_api_types.h" #include "core/session/abi_devices.h" -onnxruntime::Status OrtEpGraphSupportInfo::AddNodesToFuse(gsl::span nodes) { +onnxruntime::Status OrtEpGraphSupportInfo::AddNodesToFuse(gsl::span nodes, + const OrtNodeFusionOptions* optional_fusion_options) { std::vector ep_nodes; ep_nodes.reserve(nodes.size()); @@ -20,7 +21,8 @@ onnxruntime::Status OrtEpGraphSupportInfo::AddNodesToFuse(gsl::span ep_nodes; ep_nodes.push_back(onnxruntime::EpNode::ToInternal(node)); node_groupings.emplace_back(NodeGroupingKind::kSingleAssignedNode, std::move(ep_nodes)); - return onnxruntime::Status::OK(); } diff --git a/onnxruntime/core/session/abi_ep_types.h b/onnxruntime/core/session/abi_ep_types.h index b19a03a57a78a..eb68d79a24279 100644 --- a/onnxruntime/core/session/abi_ep_types.h +++ b/onnxruntime/core/session/abi_ep_types.h @@ -30,16 +30,19 @@ struct OrtEpGraphSupportInfo { // A grouping of supported nodes that should be handled in a single ComputeCapability. 
struct NodeGrouping { - NodeGrouping(NodeGroupingKind kind, std::vector&& nodes) - : kind(kind), nodes(std::move(nodes)) {} + NodeGrouping(NodeGroupingKind kind, std::vector&& nodes, + const OrtNodeFusionOptions& fusion_options = {}) + : kind(kind), nodes(std::move(nodes)), fusion_options(fusion_options) {} NodeGroupingKind kind = NodeGroupingKind::kInvalidGrouping; std::vector nodes; + OrtNodeFusionOptions fusion_options = {}; }; explicit OrtEpGraphSupportInfo(const onnxruntime::EpGraph& graph) : ort_graph(graph) {} - onnxruntime::Status AddNodesToFuse(gsl::span nodes); + onnxruntime::Status AddNodesToFuse(gsl::span nodes, + const OrtNodeFusionOptions* node_fusion_options = nullptr); onnxruntime::Status AddSingleNode(const OrtNode* node); const onnxruntime::EpGraph& ort_graph; diff --git a/onnxruntime/core/session/ep_api.cc b/onnxruntime/core/session/ep_api.cc index ffb5a286730ba..05fb46526ac1d 100644 --- a/onnxruntime/core/session/ep_api.cc +++ b/onnxruntime/core/session/ep_api.cc @@ -44,7 +44,8 @@ ORT_API(void, ReleaseEpDevice, _Frees_ptr_opt_ OrtEpDevice* device) { } ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* ort_graph_support_info, - _In_reads_(num_nodes) const OrtNode* const* nodes, size_t num_nodes) { + _In_reads_(num_nodes) const OrtNode* const* nodes, size_t num_nodes, + _In_opt_ const OrtNodeFusionOptions* node_fusion_options) { API_IMPL_BEGIN if (ort_graph_support_info == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid OrtGraph instance"); @@ -55,7 +56,7 @@ ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInf } gsl::span nodes_span(nodes, nodes + num_nodes); - ORT_API_RETURN_IF_STATUS_NOT_OK(ort_graph_support_info->AddNodesToFuse(nodes_span)); + ORT_API_RETURN_IF_STATUS_NOT_OK(ort_graph_support_info->AddNodesToFuse(nodes_span, node_fusion_options)); return nullptr; API_IMPL_END } diff --git a/onnxruntime/core/session/ep_api.h 
b/onnxruntime/core/session/ep_api.h index 84c8781a70adb..90c617e44260b 100644 --- a/onnxruntime/core/session/ep_api.h +++ b/onnxruntime/core/session/ep_api.h @@ -18,7 +18,8 @@ ORT_API_STATUS_IMPL(CreateEpDevice, _In_ OrtEpFactory* ep_factory, ORT_API(void, ReleaseEpDevice, _Frees_ptr_opt_ OrtEpDevice* device); ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* graph_support_info, - _In_reads_(num_nodes) const OrtNode* const* nodes, size_t num_nodes); + _In_reads_(num_nodes) const OrtNode* const* nodes, _In_ size_t num_nodes, + _In_opt_ const OrtNodeFusionOptions* node_fusion_options); ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddSingleNode, _In_ OrtEpGraphSupportInfo* graph_support_info, _In_ const OrtNode* node); ORT_API(const char*, NodeComputeContext_NodeName, _In_ const OrtNodeComputeContext* context); diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc index ebd74dd51774c..a27e62218ad52 100644 --- a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc @@ -138,6 +138,7 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie } else if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kFusedNode) { std::unordered_set node_set; node_set.reserve(node_grouping.nodes.size()); + for (const EpNode* ep_node : node_grouping.nodes) { node_set.insert(&ep_node->GetInternalNode()); } @@ -151,7 +152,8 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie // unsupported nodes in any path between supported nodes. 
std::vector> capabilities = utils::CreateSupportedPartitions( graph_viewer, node_set, /*stop_ops*/ {}, PluginEpMetaDefNameFunctor(generator, graph_viewer, this->Type()), - this->Type(), this->Type(), /*node_unit_map*/ nullptr); + this->Type(), this->Type(), /*node_unit_map*/ nullptr, + node_grouping.fusion_options.drop_constant_initializers); if (capabilities.size() > 1) { LOGS_DEFAULT(ERROR) << "OrtEp::GetCapability() set nodes that cannot be fused together. " diff --git a/onnxruntime/test/autoep/library/example_plugin_ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep.cc index 9978189267a40..581198c4a945c 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep.cc @@ -18,69 +18,108 @@ struct ExampleEp; /// Example implementation of ONNX Mul. Does not handle many things like broadcasting. /// struct MulKernel { - MulKernel(const OrtApi& ort_api, const OrtLogger& logger) : ort_api(ort_api), logger(logger) {} + MulKernel(const OrtApi& ort_api, const OrtLogger& logger, + const std::unordered_map& float_initializers, + std::string input0_name, std::string input1_name) + : ort_api(ort_api), + logger(logger), + float_initializers(float_initializers), + input0_name(input0_name), + input1_name(input1_name) {} + + const FloatInitializer* TryGetSavedInitializer(const std::string& name) const { + auto iter = float_initializers.find(name); + return iter != float_initializers.end() ? 
&iter->second : nullptr; + } + + OrtStatus* GetInputDataAndShape(OrtKernelContext* kernel_context, size_t index, + /*out*/ gsl::span& data, + /*out*/ std::vector& shape) const { + const OrtValue* input = nullptr; + RETURN_IF_ERROR(ort_api.KernelContext_GetInput(kernel_context, index, &input)); + + OrtTensorTypeAndShapeInfo* type_shape = nullptr; + DeferOrtRelease release_type(&type_shape, ort_api.ReleaseTensorTypeAndShapeInfo); + + RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(input, &type_shape)); + + ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape, &elem_type)); + RETURN_IF(elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ort_api, "Expected float32 inputs"); + + size_t num_elems = 0; + RETURN_IF_ERROR(ort_api.GetTensorShapeElementCount(type_shape, &num_elems)); + + size_t num_dims = 0; + RETURN_IF_ERROR(ort_api.GetDimensionsCount(type_shape, &num_dims)); + + shape.resize(num_dims, 0); + RETURN_IF_ERROR(ort_api.GetDimensions(type_shape, shape.data(), shape.size())); + + const float* raw_data = nullptr; + RETURN_IF_ERROR(ort_api.GetTensorMutableData(const_cast(input), (void**)&raw_data)); // No const-correct API? + + data = gsl::span(raw_data, num_elems); + return nullptr; + } OrtStatus* Compute(OrtKernelContext* kernel_context) { RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, "MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__)); + gsl::span input0; + gsl::span input1; + std::vector shape0; + std::vector shape1; + size_t num_inputs = 0; RETURN_IF_ERROR(ort_api.KernelContext_GetInputCount(kernel_context, &num_inputs)); - RETURN_IF(num_inputs != 2, ort_api, "Expected 2 inputs for MulKernel"); + + if (num_inputs == 2) { + // Both inputs are non-constant. Get them from ORT's KernelContext. 
+ RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 0, input0, shape0)); + RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 1, input1, shape1)); + } else if (num_inputs == 1) { + // ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs. + // Get the constant input from the initializers saved by the EP. + // Refer to OrtNodeFusionOptions::drop_constant_initializers. + + if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) { + RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 0, input1, shape1)); + input0 = gsl::span(const_input0->data); + shape0 = const_input0->shape; + } else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) { + RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 0, input0, shape0)); + input1 = gsl::span(const_input1->data); + shape1 = const_input1->shape; + } else { + // Neither input was saved as a constant initializer, so the missing input cannot be recovered. + // Fail explicitly instead of falling through with two empty spans and producing an unwritten output. + return ort_api.CreateStatus(ORT_EP_FAIL, "Expected one input to be a constant initializer saved by EP"); + } + } else { + // Only the zero-input case is handled here; reject any unexpected extra inputs. + RETURN_IF(num_inputs != 0, ort_api, "Expected at most 2 inputs for MulKernel"); + + // Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding) + // are disabled. + const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); + const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); + RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api, + "Expected 2 initializer inputs to be saved by EP"); + + input0 = gsl::span(const_input0->data); + input1 = gsl::span(const_input1->data); + shape0 = const_input0->shape; + shape1 = const_input1->shape; + } + + RETURN_IF(shape0 != shape1, ort_api, "Expected same dimensions for both inputs"); // No broadcasting. 
size_t num_outputs = 0; RETURN_IF_ERROR(ort_api.KernelContext_GetOutputCount(kernel_context, &num_outputs)); RETURN_IF(num_outputs != 1, ort_api, "Expected 1 output for MulKernel"); - const OrtValue* input0 = nullptr; - const OrtValue* input1 = nullptr; - RETURN_IF_ERROR(ort_api.KernelContext_GetInput(kernel_context, 0, &input0)); - RETURN_IF_ERROR(ort_api.KernelContext_GetInput(kernel_context, 1, &input1)); - - OrtTensorTypeAndShapeInfo* type_shape0 = nullptr; - OrtTensorTypeAndShapeInfo* type_shape1 = nullptr; - DeferOrtRelease release_type0(&type_shape0, ort_api.ReleaseTensorTypeAndShapeInfo); - DeferOrtRelease release_type1(&type_shape1, ort_api.ReleaseTensorTypeAndShapeInfo); - - RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(input0, &type_shape0)); - RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(input1, &type_shape1)); - - ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape0, &elem_type)); - RETURN_IF(elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ort_api, "Expected float32 inputs"); - - size_t num_dims0 = 0; - size_t num_dims1 = 0; - RETURN_IF_ERROR(ort_api.GetDimensionsCount(type_shape0, &num_dims0)); - RETURN_IF_ERROR(ort_api.GetDimensionsCount(type_shape1, &num_dims1)); - RETURN_IF((num_dims0 == 0) || (num_dims1 == 0), ort_api, "Input has 0 dimensions"); - RETURN_IF(num_dims0 != num_dims1, ort_api, "Expected same dimensions for both inputs"); // No broadcasting - - std::vector dims0(num_dims0, 0); - std::vector dims1(num_dims1, 0); - RETURN_IF_ERROR(ort_api.GetDimensions(type_shape0, dims0.data(), dims0.size())); - RETURN_IF_ERROR(ort_api.GetDimensions(type_shape1, dims1.data(), dims1.size())); - RETURN_IF(dims0 != dims1, ort_api, "Expected same dimensions for both inputs"); // No broadcasting. 
- - const float* input_data0 = nullptr; - const float* input_data1 = nullptr; - RETURN_IF_ERROR(ort_api.GetTensorMutableData(const_cast(input0), (void**)&input_data0)); // No const-correct API? - RETURN_IF_ERROR(ort_api.GetTensorMutableData(const_cast(input1), (void**)&input_data1)); - OrtValue* output = nullptr; - RETURN_IF_ERROR(ort_api.KernelContext_GetOutput(kernel_context, 0, dims0.data(), dims0.size(), &output)); - float* output_data = nullptr; + RETURN_IF_ERROR(ort_api.KernelContext_GetOutput(kernel_context, 0, shape0.data(), shape0.size(), &output)); RETURN_IF_ERROR(ort_api.GetTensorMutableData(output, reinterpret_cast(&output_data))); - int64_t num_elems = 1; - for (int64_t dim : dims0) { - RETURN_IF(dim < 0, ort_api, "Invalid dimension: negative value detected"); - num_elems *= dim; - } - - for (size_t i = 0; i < static_cast(num_elems); ++i) { - output_data[i] = input_data0[i] * input_data1[i]; + for (size_t i = 0; i < input0.size(); ++i) { + output_data[i] = input0[i] * input1[i]; } return nullptr; @@ -88,6 +127,9 @@ struct MulKernel { const OrtApi& ort_api; const OrtLogger& logger; + const std::unordered_map& float_initializers; + std::string input0_name; + std::string input1_name; }; /// @@ -142,6 +184,55 @@ struct ExampleEp : OrtEp, ApiPtrs { // Clean up the execution provider } + OrtStatus* SaveConstantInitializers(const OrtGraph* graph) { + OrtArrayOfConstObjects* initializers = nullptr; + DeferOrtRelease release_initializers(&initializers, ort_api.ReleaseArrayOfConstObjects); + size_t num_initializers = 0; + + RETURN_IF_ERROR(ort_api.Graph_GetInitializers(graph, &initializers)); + RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetSize(initializers, &num_initializers)); + + for (size_t i = 0; i < num_initializers; ++i) { + const OrtValueInfo* initializer = nullptr; + RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(initializers, i, + reinterpret_cast(&initializer))); + + bool is_constant = false; + 
RETURN_IF_ERROR(ort_api.ValueInfo_IsConstantInitializer(initializer, &is_constant)); + + if (is_constant) { + const char* name = nullptr; + const OrtValue* value = nullptr; + OrtTensorTypeAndShapeInfo* type_shape = nullptr; + DeferOrtRelease release_type(&type_shape, ort_api.ReleaseTensorTypeAndShapeInfo); + size_t num_elems = 0; + + RETURN_IF_ERROR(ort_api.GetValueInfoName(initializer, &name)); + RETURN_IF_ERROR(ort_api.ValueInfo_GetInitializerValue(initializer, &value)); + RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(value, &type_shape)); + RETURN_IF_ERROR(ort_api.GetTensorShapeElementCount(type_shape, &num_elems)); + + ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape, &elem_type)); + RETURN_IF(elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ort_api, "Expected float32 initializers"); + + size_t num_dims = 0; + RETURN_IF_ERROR(ort_api.GetDimensionsCount(type_shape, &num_dims)); + + std::vector dims(num_dims, 0); + RETURN_IF_ERROR(ort_api.GetDimensions(type_shape, dims.data(), dims.size())); + + const float* data = nullptr; + RETURN_IF_ERROR(ort_api.GetTensorMutableData(const_cast(value), (void**)&data)); + + FloatInitializer ep_initializer = {std::move(dims), std::vector(data, data + num_elems)}; + this->float_initializers.emplace(name, std::move(ep_initializer)); + } + } + + return nullptr; + } + static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) { const auto* ep = static_cast(this_ptr); return ep->name_.c_str(); @@ -206,8 +297,19 @@ struct ExampleEp : OrtEp, ApiPtrs { break; } } + + // Create (optional) fusion options for the supported nodes to fuse. + OrtNodeFusionOptions node_fusion_options = {}; + node_fusion_options.ort_version_supported = ORT_API_VERSION; + + // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers + // as inputs to the fused/compiled node at inference time. 
This allows ORT to release unused initializers. + // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use + // during inference. + node_fusion_options.drop_constant_initializers = true; + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, supported_nodes.data(), - supported_nodes.size())); + supported_nodes.size(), &node_fusion_options)); return nullptr; } @@ -216,39 +318,62 @@ struct ExampleEp : OrtEp, ApiPtrs { _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, _Out_writes_(count) OrtNode** ep_context_nodes) { ExampleEp* ep = static_cast(this_ptr); + const OrtApi& ort_api = ep->ort_api; if (count != 1) { return ep->ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single graph"); } + // In GetCapability(), this EP specified that it doesn't need ORT to provide constant initializers during inference. + // So, this EP saves constant initializers so that they're available during inference, but an actual EP + // implementation could transfer the weights to device memory. 
+ // Check the result: the returned OrtStatus must not be dropped, and MulKernel reads these initializers at inference. + RETURN_IF_ERROR(ep->SaveConstantInitializers(graphs[0])); + OrtArrayOfConstObjects* nodes_array = nullptr; - DeferOrtRelease release_nodes(&nodes_array, ep->ort_api.ReleaseArrayOfConstObjects); + DeferOrtRelease release_nodes(&nodes_array, ort_api.ReleaseArrayOfConstObjects); size_t num_nodes = 0; - RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graphs[0], &nodes_array)); - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(nodes_array, &num_nodes)); + RETURN_IF_ERROR(ort_api.Graph_GetNodes(graphs[0], &nodes_array)); + RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetSize(nodes_array, &num_nodes)); if (num_nodes != 1) { - return ep->ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single Mul node"); + return ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single Mul node"); } const OrtNode* node_to_compile = nullptr; - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetElementAt(nodes_array, 0, - reinterpret_cast(&node_to_compile))); + RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(nodes_array, 0, + reinterpret_cast(&node_to_compile))); const char* node_op_type = nullptr; - RETURN_IF_ERROR(ep->ort_api.Node_GetOperatorType(node_to_compile, &node_op_type)); + RETURN_IF_ERROR(ort_api.Node_GetOperatorType(node_to_compile, &node_op_type)); if (std::strncmp(node_op_type, "Mul", 4) != 0) { - return ep->ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single Mul node"); + return ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single Mul node"); } - // Now we know we're compiling a single Mul node. + // Now we know we're compiling a single Mul node. Create a computation kernel. 
+ OrtArrayOfConstObjects* inputs = nullptr; + DeferOrtRelease release_inputs(&inputs, ep->ort_api.ReleaseArrayOfConstObjects); + + RETURN_IF_ERROR(ep->ort_api.Node_GetInputs(node_to_compile, &inputs)); + const OrtValueInfo* input0 = nullptr; + const OrtValueInfo* input1 = nullptr; + + RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(inputs, 0, reinterpret_cast(&input0))); + RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(inputs, 1, reinterpret_cast(&input1))); + + const char* input0_name = nullptr; + const char* input1_name = nullptr; + RETURN_IF_ERROR(ort_api.GetValueInfoName(input0, &input0_name)); + RETURN_IF_ERROR(ort_api.GetValueInfoName(input1, &input1_name)); + // Associate the name of the fused node with our MulKernel. const char* fused_node_name = nullptr; RETURN_IF_ERROR(ep->ort_api.Node_GetName(fused_nodes[0], &fused_node_name)); - ep->kernels.emplace(std::string(fused_node_name), std::make_unique(ep->ort_api, ep->logger_)); + ep->kernels.emplace(std::string(fused_node_name), std::make_unique(ep->ort_api, ep->logger_, + ep->float_initializers, + input0_name, input1_name)); // Update the OrtNodeComputeInfo associated with the graph. auto node_compute_info = std::make_unique(*ep); @@ -347,6 +472,7 @@ struct ExampleEp : OrtEp, ApiPtrs { Config config_{}; const OrtLogger& logger_; std::unordered_map> kernels; + std::unordered_map float_initializers; }; // diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h index ae0a86bbb7222..d1243b4a6aeda 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h @@ -49,6 +49,11 @@ struct DeferOrtRelease { std::function release_func_ = nullptr; }; +struct FloatInitializer { + std::vector shape; + std::vector data; +}; + // Returns an entry in the session option configurations, or a default value if not present. 
OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& ort_api, const OrtSessionOptions& session_options, const char* config_key, const std::string& default_val,