Skip to content
Merged
75 changes: 75 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,16 @@ struct OrtEpApi {
ORT_API_T(const char*, NodeComputeContext_NodeName, _In_ const OrtNodeComputeContext* context);
};

/**
* \brief The data layout type that is preferred by an EP.
* \since Version 1.23.
*/
typedef enum OrtEpDataLayout {
NCHW = 0,
NHWC,
NCHWC,
} OrtEpDataLayout;

/**
* \brief The OrtEp struct provides functions to implement for an execution provider.
* \since Version 1.22.
Expand Down Expand Up @@ -217,6 +227,71 @@ struct OrtEp {
void(ORT_API_CALL* ReleaseNodeComputeInfos)(_In_ OrtEp* this_ptr,
OrtNodeComputeInfo** node_compute_infos,
_In_ size_t num_node_compute_infos);

/** \brief Get the EP's preferred data layout.
*
* \note Implementation of this function is optional (the function pointer may be left NULL).
* If not implemented, ORT will assume that this EP prefers the data layout `OrtEpDataLayout::NCHW`.
*
* \param[in] this_ptr The OrtEp instance.
* \param[out] preferred_data_layout The EP's preferred data layout.
*                                   Must be set to one of the OrtEpDataLayout enumerators;
*                                   ORT treats any other value as an error.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
OrtStatus*(ORT_API_CALL* GetPreferredDataLayout)(_In_ OrtEp* this_ptr,
_Out_ OrtEpDataLayout* preferred_data_layout);

/** \brief Set dynamic options on this EP.
*
* \param[in] this_ptr The OrtEp instance.
* \param[in] option_keys The dynamic option keys. Array of `num_options` elements.
* \param[in] option_values The dynamic option values. Array of `num_options` elements
*                          (ORT always passes the same number of values as keys).
* \param[in] num_options The number of dynamic options.
*
* \note Implementation of this function is optional (the function pointer may be left NULL).
* An EP should only implement this if it needs to handle any dynamic options.
* If not implemented, ORT does not call into the EP and uses its built-in default handling instead.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
OrtStatus*(ORT_API_CALL* SetDynamicOptions)(_In_ OrtEp* this_ptr,
_In_reads_(num_options) const char* const* option_keys,
_In_reads_(num_options) const char* const* option_values,
_In_ size_t num_options);

/** \brief Called by ORT to notify the EP of the start of a run.
*
* \param[in] this_ptr The OrtEp instance.
* \param[in] run_options The run options for this run.
*
* \note Implementation of this function is optional (the function pointer may be left NULL).
* If not implemented, ORT does not call into the EP and uses its built-in default handling instead.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
OrtStatus*(ORT_API_CALL* OnRunStart)(_In_ OrtEp* this_ptr,
_In_ const OrtRunOptions* run_options);

/** \brief Called by ORT to notify the EP of the end of a run.
*
* \param[in] this_ptr The OrtEp instance.
* \param[in] run_options The run options for this run.
* \param[in] sync_stream Whether any associated stream should be synchronized during this call.
*                        Only applicable if there is such a stream.
*
* \note Implementation of this function is optional (the function pointer may be left NULL).
* If not implemented, ORT does not call into the EP and uses its built-in default handling instead.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
OrtStatus*(ORT_API_CALL* OnRunEnd)(_In_ OrtEp* this_ptr,
_In_ const OrtRunOptions* run_options,
_In_ bool sync_stream);
};

/** \brief The function signature that ORT will call to create OrtEpFactory instances.
Expand Down
59 changes: 57 additions & 2 deletions onnxruntime/core/session/ep_plugin_provider_interfaces.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ Status PluginExecutionProvider::FusedNodeState::AddFusedNode(const Node& fused_n
return Status::OK();
}

common::Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_infos) {
Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_infos) {
const logging::Logger* logger = GetLogger();
const size_t num_graphs = fused_nodes_and_graphs.size();
std::vector<std::unique_ptr<EpGraph>> api_graphs_holder;
Expand Down Expand Up @@ -270,4 +270,59 @@ common::Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGr

return Status::OK();
}

// Reports the data layout preferred by the wrapped OrtEp, translated to the internal DataLayout type.
//
// When the OrtEp leaves GetPreferredDataLayout unset, the IExecutionProvider default is returned instead.
// Throws (via ORT_THROW_IF_ERROR / ORT_THROW) if the OrtEp call returns an error status or produces a value
// that is not one of the OrtEpDataLayout enumerators.
DataLayout PluginExecutionProvider::GetPreferredLayout() const {
  if (ort_ep_->GetPreferredDataLayout != nullptr) {
    OrtEpDataLayout layout_from_ep{};

    ORT_THROW_IF_ERROR(ToStatusAndRelease(ort_ep_->GetPreferredDataLayout(ort_ep_.get(), &layout_from_ep)));

    // Map the public C API enum onto the internal enum, one enumerator at a time.
    if (layout_from_ep == OrtEpDataLayout::NCHW) {
      return DataLayout::NCHW;
    }

    if (layout_from_ep == OrtEpDataLayout::NHWC) {
      return DataLayout::NHWC;
    }

    if (layout_from_ep == OrtEpDataLayout::NCHWC) {
      return DataLayout::NCHWC;
    }

    // Anything else did not come from the OrtEpDataLayout enumeration.
    ORT_THROW("OrtEp::GetPreferredDataLayout() returned an invalid data layout: ",
              static_cast<int>(layout_from_ep));
  }

  // The OrtEp does not implement the optional function.
  return Base::GetPreferredLayout();
}

// Notifies the wrapped OrtEp that a run is starting.
// Uses the IExecutionProvider default when the OrtEp does not implement the optional OnRunStart().
Status PluginExecutionProvider::OnRunStart(const RunOptions& run_options) {
  const bool ep_implements_on_run_start = ort_ep_->OnRunStart != nullptr;

  if (ep_implements_on_run_start) {
    // Convert the OrtStatus* returned by the OrtEp into a Status (releasing the OrtStatus).
    OrtStatus* ep_status = ort_ep_->OnRunStart(ort_ep_.get(), &run_options);
    return ToStatusAndRelease(ep_status);
  }

  return Base::OnRunStart(run_options);
}

// Notifies the wrapped OrtEp that a run has ended.
// Uses the IExecutionProvider default when the OrtEp does not implement the optional OnRunEnd().
// Note that the OrtEp function takes its arguments in the order (run_options, sync_stream).
Status PluginExecutionProvider::OnRunEnd(bool sync_stream, const RunOptions& run_options) {
  const bool ep_implements_on_run_end = ort_ep_->OnRunEnd != nullptr;

  if (ep_implements_on_run_end) {
    // Convert the OrtStatus* returned by the OrtEp into a Status (releasing the OrtStatus).
    OrtStatus* ep_status = ort_ep_->OnRunEnd(ort_ep_.get(), &run_options, sync_stream);
    return ToStatusAndRelease(ep_status);
  }

  return Base::OnRunEnd(sync_stream, run_options);
}

// Forwards dynamic options to the wrapped OrtEp.
//
// Uses the IExecutionProvider default when the OrtEp does not implement the optional SetDynamicOptions().
// Otherwise returns an error status if `keys` and `values` differ in length, and the OrtEp's result if not.
Status PluginExecutionProvider::SetEpDynamicOptions(gsl::span<const char* const> keys,
                                                    gsl::span<const char* const> values) {
  const bool ep_implements_set_dynamic_options = ort_ep_->SetDynamicOptions != nullptr;

  if (!ep_implements_set_dynamic_options) {
    return Base::SetEpDynamicOptions(keys, values);
  }

  // The C API passes a single count for both arrays, so their lengths have to agree.
  ORT_RETURN_IF_NOT(keys.size() == values.size(),
                    "The number of keys (", keys.size(), ") and number of values (", values.size(),
                    ") must be the same.");

  OrtStatus* ep_status = ort_ep_->SetDynamicOptions(ort_ep_.get(), keys.data(), values.data(), keys.size());
  return ToStatusAndRelease(ep_status);
}

} // namespace onnxruntime
16 changes: 14 additions & 2 deletions onnxruntime/core/session/ep_plugin_provider_interfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ using UniqueOrtEp = std::unique_ptr<OrtEp, OrtEpDeleter>;
/// IExecutionProvider that wraps an instance of OrtEp.
/// </summary>
class PluginExecutionProvider : public IExecutionProvider {
private:
using Base = IExecutionProvider;

public:
explicit PluginExecutionProvider(UniqueOrtEp ep);
~PluginExecutionProvider();
Expand All @@ -68,8 +71,17 @@ class PluginExecutionProvider : public IExecutionProvider {
const GraphOptimizerRegistry& graph_optimizer_registry,
IResourceAccountant* resource_accountant = nullptr) const override;

common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) override;
Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) override;

DataLayout GetPreferredLayout() const override;

Status OnRunStart(const RunOptions& run_options) override;

Status OnRunEnd(bool sync_stream, const RunOptions& run_options) override;

Status SetEpDynamicOptions(gsl::span<const char* const> keys,
gsl::span<const char* const> values) override;

private:
struct FusedNodeState {
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ ORT_API_STATUS_IMPL(OrtApis::SetEpDynamicOptions, _Inout_ OrtSession* sess,
Status status;

if (kv_len == 0) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "no imputs were passed");
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "no inputs were passed");
} else {
status = session->SetEpDynamicOptions(keys_span,
values_span);
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/test/autoep/library/example_plugin_ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ static OrtStatus* IsFloatTensor(const OrtApi& ort_api, const OrtValueInfo* value
/// </summary>
struct ExampleEp : OrtEp, ApiPtrs {
ExampleEp(ApiPtrs apis, const std::string& name, const OrtSessionOptions& session_options, const OrtLogger& logger)
: ApiPtrs(apis), name_{name}, logger_{logger} {
: OrtEp(), ApiPtrs(apis), name_{name}, logger_{logger} {
// Initialize the execution provider.
auto status = ort_api.Logger_LogMessage(&logger_,
OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
Expand Down Expand Up @@ -323,7 +323,7 @@ struct ExampleEp : OrtEp, ApiPtrs {
};

//
// Implementation of ExampleNodeComuteInfo
// Implementation of ExampleNodeComputeInfo
//
ExampleNodeComputeInfo::ExampleNodeComputeInfo(ExampleEp& ep) : ep(ep) {
ort_version_supported = ORT_API_VERSION;
Expand Down
113 changes: 113 additions & 0 deletions onnxruntime/test/framework/ep_plugin_provider_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/session/ep_plugin_provider_interfaces.h"

#include "gsl/gsl"
#include "gtest/gtest.h"

#include "core/session/onnxruntime_c_api.h"
#include "test/util/include/asserts.h"

namespace onnxruntime::test {

// Bundles the public ORT API tables used by the tests in this file.
// Both pointers are looked up on construction and are guaranteed non-null by gsl::not_null.
struct ApiPtrs {
  // OrtApi for ORT_API_VERSION, obtained from the API base.
  const gsl::not_null<const ::OrtApi*> ort_api{::OrtGetApiBase()->GetApi(ORT_API_VERSION)};

  // OrtEpApi obtained from `ort_api`. Declared after `ort_api` because its initializer reads it.
  const gsl::not_null<const ::OrtEpApi*> ep_api{ort_api->GetEpApi()};
};

// Normally, a plugin EP would be implemented in a separate library.
// The `test_plugin_ep` namespace contains a local implementation intended for unit testing.
namespace test_plugin_ep {

// Minimal OrtEp used by the unit tests in this file.
// Construction value-initializes the OrtEp base and then sets only the supported API version and GetName().
struct TestOrtEp : ::OrtEp, ApiPtrs {
  TestOrtEp() : ::OrtEp{}, ApiPtrs{} {
    // Individual tests should fill out the other function pointers as needed.
    GetName = GetNameImpl;
    ort_version_supported = ORT_API_VERSION;
  }

  // OrtEp::GetName() implementation. Always reports the fixed name "TestOrtEp".
  static const char* ORT_API_CALL GetNameImpl(const OrtEp* /*this_ptr*/) {
    return "TestOrtEp";
  }
};

// This factory doesn't do anything other than implement ReleaseEp().
// It is only used to create the UniqueOrtEp that is required by PluginExecutionProvider.
struct TestOrtEpFactory : ::OrtEpFactory {
TestOrtEpFactory() : ::OrtEpFactory{} {
ort_version_supported = ORT_API_VERSION;
ReleaseEp = ReleaseEpImpl;
}

static void ORT_API_CALL ReleaseEpImpl(::OrtEpFactory* /*this_ptr*/, OrtEp* ep) {
delete static_cast<TestOrtEp*>(ep);
}
};

static TestOrtEpFactory g_test_ort_ep_factory{};

struct MakeTestOrtEpResult {
std::unique_ptr<IExecutionProvider> ep; // the IExecutionProvider wrapping the TestOrtEp
gsl::not_null<TestOrtEp*> ort_ep; // the wrapped TestOrtEp, owned by `ep`
};

// Creates an IExecutionProvider that wraps a TestOrtEp.
// The TestOrtEp is also exposed so that tests can manipulate its function pointers directly.
MakeTestOrtEpResult MakeTestOrtEp() {
auto ort_ep_raw = std::make_unique<TestOrtEp>().release();
auto ort_ep = UniqueOrtEp(ort_ep_raw, OrtEpDeleter{g_test_ort_ep_factory});
auto ep = std::make_unique<PluginExecutionProvider>(std::move(ort_ep));
auto result = MakeTestOrtEpResult{std::move(ep), ort_ep_raw};
return result;
}

} // namespace test_plugin_ep

// Checks how PluginExecutionProvider::GetPreferredLayout() uses OrtEp::GetPreferredDataLayout():
//  - function pointer not set: the default layout (NCHW) is reported
//  - function pointer set: the layout written by the OrtEp is translated and reported
//  - invalid layout value or error status from the OrtEp: an exception is thrown (exception-enabled builds only)
TEST(PluginExecutionProviderTest, GetPreferredLayout) {
  auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp();

  // Optional function not implemented by the OrtEp.
  {
    ort_ep->GetPreferredDataLayout = nullptr;
    ASSERT_EQ(ep->GetPreferredLayout(), DataLayout::NCHW);
  }

  // Optional function implemented by the OrtEp.
  // A non-default layout (NHWC) is used on purpose: expecting NCHW here would also pass if the
  // callback were never invoked, because NCHW is what the unimplemented case reports.
  {
    auto prefer_nhwc_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) -> ::OrtStatus* {
      *preferred_data_layout = OrtEpDataLayout::NHWC;
      return nullptr;
    };
    ort_ep->GetPreferredDataLayout = prefer_nhwc_fn;
    ASSERT_EQ(ep->GetPreferredLayout(), DataLayout::NHWC);
  }

#if !defined(ORT_NO_EXCEPTIONS)
  // The OrtEp succeeds but writes a value that is not an OrtEpDataLayout enumerator.
  {
    auto invalid_layout_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) -> ::OrtStatus* {
      *preferred_data_layout = static_cast<OrtEpDataLayout>(-1);
      return nullptr;
    };
    ort_ep->GetPreferredDataLayout = invalid_layout_fn;
    ASSERT_THROW(ep->GetPreferredLayout(), OnnxRuntimeException);
  }

  // The OrtEp returns an error status.
  {
    auto failing_fn = [](OrtEp* this_ptr, OrtEpDataLayout* /*preferred_data_layout*/) -> ::OrtStatus* {
      // The OrtEp passed in is the TestOrtEp created above, which carries the OrtApi pointer via ApiPtrs.
      auto* test_ort_ep = static_cast<test_plugin_ep::TestOrtEp*>(this_ptr);
      return test_ort_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL, "I can't decide what data layout I prefer.");
    };
    ort_ep->GetPreferredDataLayout = failing_fn;
    ASSERT_THROW(ep->GetPreferredLayout(), OnnxRuntimeException);
  }
#endif  // !defined(ORT_NO_EXCEPTIONS)
}

} // namespace onnxruntime::test
Loading