Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,23 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord
return Status::OK();
}

// Applies `context_priority` to every QNN context currently owned by this
// backend manager.
//
// @param context_priority  Priority level to apply (validated/translated by
//                          SetQnnContextConfig before use).
// @return Status::OK() on success; an error Status if the priority value is
//         invalid or if contextSetConfig fails for any context handle.
Status QnnBackendManager::SetContextPriority(ContextPriority context_priority) {
  QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT;
  ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority, context_priority_config));

  // The QNN API expects a null-terminated array of const config pointers.
  // Declaring the array as pointer-to-const avoids the cast at the call site.
  const QnnContext_Config_t* configs[] = {&context_priority_config, nullptr};
  for (const auto& context_handle : contexts_) {
    auto result = qnn_interface_.contextSetConfig(context_handle, configs);
    ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result,
                  "Failed to set context priority for context handle: ", context_handle);
  }

  return Status::OK();
}

// Restores every QNN context's priority to the session default stored in
// context_priority_ (the value configured when the backend manager was
// created). Returns any error propagated from SetContextPriority.
Status QnnBackendManager::ResetContextPriority() {
  return SetContextPriority(context_priority_);
}

Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) {
if (true == context_created_) {
LOGS_DEFAULT(INFO) << "Context created already.";
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,11 @@ class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager>
// For each node name, a mapping to the context handle will be created
void ProcessContextFromBinListAsync(Qnn_ContextHandle_t handle, void* notifyParam);

// Sets the context priority to the given value, if valid
Status SetContextPriority(ContextPriority context_priority);
// Resets the context priority to the session default as defined by context_priority_
Status ResetContextPriority();

private:
Status LoadBackend();

Expand Down
34 changes: 34 additions & 0 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1545,4 +1545,38 @@
return default_device_;
}

// Applies dynamic (post-session-creation) EP options supplied as parallel
// key/value arrays.
//
// Currently supported:
//   kOrtEpDynamicOptionsWorkloadType:
//     "Default"   -> reset QNN context priority to the session default.
//     "Efficient" -> lower QNN context priority (background workload).
//
// @param keys    Option names; must be the same length as `values`.
// @param values  Option values, parallel to `keys`.
// @return Status::OK() on success; INVALID_ARGUMENT for a size mismatch, an
//         unknown key, or an unsupported workload-type value.
Status QNNExecutionProvider::SetEpDynamicOptions(gsl::span<const char* const> keys,
                                                 gsl::span<const char* const> values) {
  if (keys.size() != values.size()) {
    LOGS_DEFAULT(ERROR) << "SetEpDynamicOptions: number of keys (" << keys.size()
                        << ") does not equal number of values (" << values.size() << ").";
    // Fix: previously this only logged and then silently processed the shorter
    // prefix of the two arrays. A malformed request must be rejected outright.
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "SetEpDynamicOptions: mismatched number of keys and values.");
  }

  for (size_t i = 0; i < keys.size(); ++i) {
    std::string key(keys[i]);
    std::string value(values[i]);

    if (key == kOrtEpDynamicOptionsWorkloadType) {
      if (value == "Default") {
        ORT_RETURN_IF_ERROR(qnn_backend_manager_->ResetContextPriority());
      } else if (value == "Efficient") {
        ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetContextPriority(qnn::ContextPriority::LOW));
      } else {
        LOGS_DEFAULT(ERROR) << "Invalid EP Workload Type: " << value;
        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid EP Workload Type.");
      }
    } else {
      LOGS_DEFAULT(ERROR) << "EP Dynamic Option \"" << key << "\" is not currently supported.";
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported EP Dynamic Option");
    }
  }

  return Status::OK();
}

} // namespace onnxruntime
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ class QNNExecutionProvider : public IExecutionProvider {

OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;

Status SetEpDynamicOptions(gsl::span<const char* const> keys,
gsl::span<const char* const> value) override;

private:
std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewer,
const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
Expand Down
69 changes: 67 additions & 2 deletions onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1649,7 +1649,6 @@ static void DumpModelWithSharedCtx(ProviderOptions provider_options,
Ort::Session session2(*ort_env, ToPathString(onnx_model_path2).c_str(), so);
}

#if defined(__aarch64__) || defined(_M_ARM64)
static void GetModelInputNames(const std::string& model_path,
std::vector<std::string>& input_names,
std::vector<std::string>& output_names,
Expand All @@ -1669,7 +1668,6 @@ static void GetModelInputNames(const std::string& model_path,
output_names.push_back(output->Name());
}
}
#endif

// 1. Create 2 QDQ models
// 2. Initialize 2 Ort sessions which share the same QNN EP from these 2 QDQ models
Expand Down Expand Up @@ -1994,6 +1992,73 @@ TEST_F(QnnHTPBackendTests, LoadFromArrayWithQnnEpContextGenPathValidation) {
});
}
}

// Verifies that the QNN EP honors the "ep.dynamic.workload_type" dynamic
// option ("Efficient" / "Default"), and that unknown option names and unknown
// workload-type values are rejected with the expected error messages.
TEST_F(QnnHTPBackendTests, QnnEpDynamicOptions) {
  ProviderOptions provider_options;
  provider_options["backend_type"] = "htp";
  provider_options["offload_graph_io_quantization"] = "0";

  Ort::SessionOptions so;
  so.AppendExecutionProvider("QNN", provider_options);
  so.SetLogSeverityLevel(ORT_LOGGING_LEVEL_VERBOSE);

  Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx/qnn_multi_ctx_embed.onnx"), so);

  std::vector<std::string> input_names;
  std::vector<std::string> output_names;
  GetModelInputNames("testdata/qnn_ctx/qnn_multi_ctx_embed.onnx", input_names, output_names,
                     DefaultLoggingManager().DefaultLogger());

  // Build one zero-filled float tensor per model input. All tensors share the
  // same backing buffer; only the first model output is fetched.
  std::vector<int64_t> input_shape{3, 4};
  std::vector<float> input_data(3 * 4, 0.0f);
  Ort::MemoryInfo memory_info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
  std::vector<Ort::Value> ort_inputs;
  std::vector<const char*> input_names_c;
  for (const auto& name : input_names) {
    ort_inputs.push_back(Ort::Value::CreateTensor(memory_info, input_data.data(), input_data.size(),
                                                  input_shape.data(), input_shape.size()));
    input_names_c.push_back(name.c_str());
  }
  std::vector<const char*> output_names_c;
  for (const auto& name : output_names) {
    output_names_c.push_back(name.c_str());
  }

  auto run_session = [&]() {
    return session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(),
                       output_names_c.data(), 1);
  };

  auto ort_output = run_session();

  const char* const workload_type[] = {"ep.dynamic.workload_type"};
  const char* const efficient_type[] = {"Efficient"};
  const char* const default_type[] = {"Default"};

  // Switching to "Efficient" and back to "Default" should both succeed, and
  // the session should remain runnable after each change.
  session.SetEpDynamicOptions(workload_type, efficient_type, 1);
  ort_output = run_session();

  session.SetEpDynamicOptions(workload_type, default_type, 1);
  ort_output = run_session();

  // An unrecognized workload-type value must be rejected.
  const char* const dne[] = {"DNE"};
  try {
    session.SetEpDynamicOptions(workload_type, dne, 1);
    FAIL() << "Expected exception to be thrown for workload type DNE but was set successfully";
  } catch (const std::exception& e) {
    EXPECT_STREQ("Invalid EP Workload Type.", e.what());
  }

  // An unrecognized option key must be rejected as well.
  try {
    session.SetEpDynamicOptions(dne, efficient_type, 1);
    FAIL() << "Expected exception to be thrown for dynamic option DNE but was set successfully";
  } catch (const std::exception& e) {
    EXPECT_STREQ("Unsupported EP Dynamic Option", e.what());
  }
}
#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)

} // namespace test
Expand Down
Loading