Skip to content

Commit 11401c1

Browse files
WeldonWangwang, Wovchena, and Copilot
authored
CB: Hetero pipeline parallel support (#2227)
Depends on: openvinotoolkit/openvino#30371. Tickets: CVS-164805 --------- Co-authored-by: Vladimir Zlobin <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 9a63488 commit 11401c1

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

src/cpp/src/continuous_batching/cache_manager.hpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,18 @@ class CacheManager {
3838
// extract information about inference device
3939
ov::CompiledModel compiled_model = request.get_compiled_model();
4040
std::vector<std::string> execution_devices = compiled_model.get_property(ov::execution_devices);
41-
OPENVINO_ASSERT(execution_devices.size() == 1, "Contituous batching: execution device is expected to be CPU or GPU, but got ", execution_devices.size(), " devices");
41+
const bool all_gpu_device =
42+
std::all_of(execution_devices.begin(), execution_devices.end(), [&](const std::string& device) {
43+
return device.find("GPU") != std::string::npos;
44+
});
45+
OPENVINO_ASSERT(all_gpu_device || execution_devices.size() == 1,
46+
"Continuous batching: execution device is expected to be single CPU / single GPU / multi GPUs");
4247
m_device = execution_devices[0];
43-
4448
// set block_size depending on device
4549
const size_t cpu_block_size = 32, gpu_block_size = 16;
46-
const bool is_gpu = m_device.find("GPU") != std::string::npos;
47-
m_block_size = is_gpu ? gpu_block_size : cpu_block_size;
50+
m_block_size = all_gpu_device ? gpu_block_size : cpu_block_size;
4851

49-
if (is_gpu) {
52+
if (all_gpu_device) {
5053
m_context = m_request.get_compiled_model().get_context();
5154
}
5255
// extract information about KV cache precisions and shapes

src/cpp/src/continuous_batching/pipeline_impl.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,12 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::initialize_pipeline(
104104

105105
ov::CompiledModel compiled_model = utils::singleton_core().compile_model(model, device, *filtered_properties);
106106
std::vector<std::string> execution_devices = compiled_model.get_property(ov::execution_devices);
107-
OPENVINO_ASSERT(execution_devices.size() == 1, "Contituous batching: execution device is expected to be CPU or GPU, but got ", execution_devices.size(), " devices");
107+
const bool all_gpu_device =
108+
std::all_of(execution_devices.begin(), execution_devices.end(), [&](const std::string& device) {
109+
return device.find("GPU") != std::string::npos;
110+
});
111+
OPENVINO_ASSERT(all_gpu_device || execution_devices.size() == 1,
112+
"Continuous batching: execution device is expected to be single CPU / single GPU / multi GPUs");
108113
const std::string execution_device = execution_devices[0];
109114

110115
ov::genai::utils::print_compiled_model_properties(compiled_model, "LLM with Paged Attention");

0 commit comments

Comments (0)