-
Notifications
You must be signed in to change notification settings - Fork 3.5k
[ORT EP API] Add some additional ORT EP APIs #25127
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 7 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
e694e07
add EP API GetPreferredDataLayout()
edgchen1 5d8601b
add EP API SetEpDynamicOptions()
edgchen1 e45d846
set SetDynamicOptions ptr to null
edgchen1 0266ceb
add EP APIs OnRunStart() and OnRunEnd()
edgchen1 9e3b3f4
fix typo
edgchen1 ea42c1d
add ep_plugin_provider_test.cc
edgchen1 f4921a0
clean up tests
edgchen1 e0226d7
add GetRunConfigEntry C API
edgchen1 a7bb167
remove NCHWC data layout - it did not seem to be used. add prefix for…
edgchen1 b5f2adb
Add remark about where dynamic options come from.
edgchen1 b08db16
Merge remote-tracking branch 'origin/main' into edgchen1/ep_api_updates
edgchen1 d5d0dbf
add PluginExecutionProvider ctor arg for session options
edgchen1 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "core/session/ep_plugin_provider_interfaces.h" | ||
|
||
#include "gsl/gsl" | ||
#include "gtest/gtest.h" | ||
|
||
#include "core/session/onnxruntime_c_api.h" | ||
#include "test/util/include/asserts.h" | ||
|
||
namespace onnxruntime::test { | ||
|
||
// Helper class to access public ORT APIs. | ||
struct ApiPtrs { | ||
ApiPtrs() : ort_api{::OrtGetApiBase()->GetApi(ORT_API_VERSION)}, | ||
ep_api{ort_api->GetEpApi()} { | ||
} | ||
|
||
const gsl::not_null<const ::OrtApi*> ort_api; | ||
const gsl::not_null<const ::OrtEpApi*> ep_api; | ||
}; | ||
|
||
// Normally, a plugin EP would be implemented in a separate library. | ||
// The `test_plugin_ep` namespace contains a local implementation intended for unit testing. | ||
namespace test_plugin_ep { | ||
|
||
struct TestOrtEp : ::OrtEp, ApiPtrs { | ||
TestOrtEp() : ::OrtEp{}, ApiPtrs{} { | ||
ort_version_supported = ORT_API_VERSION; | ||
|
||
GetName = GetNameImpl; | ||
|
||
// Individual tests should fill out the other function pointers as needed. | ||
} | ||
|
||
static const char* ORT_API_CALL GetNameImpl(const OrtEp* /*this_ptr*/) { | ||
constexpr const char* ep_name = "TestOrtEp"; | ||
return ep_name; | ||
} | ||
}; | ||
|
||
// This factory doesn't do anything other than implement ReleaseEp(). | ||
// It is only used to create the UniqueOrtEp that is required by PluginExecutionProvider. | ||
struct TestOrtEpFactory : ::OrtEpFactory { | ||
TestOrtEpFactory() : ::OrtEpFactory{} { | ||
ort_version_supported = ORT_API_VERSION; | ||
ReleaseEp = ReleaseEpImpl; | ||
} | ||
|
||
static void ORT_API_CALL ReleaseEpImpl(::OrtEpFactory* /*this_ptr*/, OrtEp* ep) { | ||
delete static_cast<TestOrtEp*>(ep); | ||
} | ||
}; | ||
|
||
static TestOrtEpFactory g_test_ort_ep_factory{}; | ||
|
||
struct MakeTestOrtEpResult { | ||
std::unique_ptr<IExecutionProvider> ep; // the IExecutionProvider wrapping the TestOrtEp | ||
gsl::not_null<TestOrtEp*> ort_ep; // the wrapped TestOrtEp, owned by `ep` | ||
}; | ||
|
||
// Creates an IExecutionProvider that wraps a TestOrtEp. | ||
// The TestOrtEp is also exposed so that tests can manipulate its function pointers directly. | ||
MakeTestOrtEpResult MakeTestOrtEp() { | ||
auto ort_ep_raw = std::make_unique<TestOrtEp>().release(); | ||
auto ort_ep = UniqueOrtEp(ort_ep_raw, OrtEpDeleter{g_test_ort_ep_factory}); | ||
auto ep = std::make_unique<PluginExecutionProvider>(std::move(ort_ep)); | ||
auto result = MakeTestOrtEpResult{std::move(ep), ort_ep_raw}; | ||
return result; | ||
} | ||
|
||
} // namespace test_plugin_ep | ||
|
||
TEST(PluginExecutionProviderTest, GetPreferredLayout) { | ||
auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(); | ||
|
||
{ | ||
ort_ep->GetPreferredDataLayout = nullptr; | ||
ASSERT_EQ(ep->GetPreferredLayout(), DataLayout::NCHW); | ||
} | ||
|
||
{ | ||
auto prefer_nhwc_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) -> ::OrtStatus* { | ||
*preferred_data_layout = OrtEpDataLayout::NCHW; | ||
return nullptr; | ||
}; | ||
ort_ep->GetPreferredDataLayout = prefer_nhwc_fn; | ||
ASSERT_EQ(ep->GetPreferredLayout(), DataLayout::NCHW); | ||
} | ||
|
||
#if !defined(ORT_NO_EXCEPTIONS) | ||
{ | ||
auto invalid_layout_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) -> ::OrtStatus* { | ||
*preferred_data_layout = static_cast<OrtEpDataLayout>(-1); | ||
return nullptr; | ||
}; | ||
ort_ep->GetPreferredDataLayout = invalid_layout_fn; | ||
ASSERT_THROW(ep->GetPreferredLayout(), OnnxRuntimeException); | ||
} | ||
|
||
{ | ||
auto failing_fn = [](OrtEp* this_ptr, OrtEpDataLayout* /*preferred_data_layout*/) -> ::OrtStatus* { | ||
auto* test_ort_ep = static_cast<test_plugin_ep::TestOrtEp*>(this_ptr); | ||
return test_ort_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL, "I can't decide what data layout I prefer."); | ||
}; | ||
ort_ep->GetPreferredDataLayout = failing_fn; | ||
ASSERT_THROW(ep->GetPreferredLayout(), OnnxRuntimeException); | ||
} | ||
#endif // !defined(ORT_NO_EXCEPTIONS) | ||
} | ||
|
||
} // namespace onnxruntime::test |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.