From 8a24f729da82f42ac683ce3979b1df8c8328d8f0 Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Tue, 1 Jul 2025 14:58:59 +0200 Subject: [PATCH 01/10] add performance metrics for StructuredOutput generation --- .../include/openvino/genai/perf_metrics.hpp | 9 ++ .../src/continuous_batching/pipeline_impl.cpp | 20 ++++ src/cpp/src/perf_metrics.cpp | 14 +++ src/cpp/src/sampling/logit_processor.hpp | 11 +++ src/cpp/src/sampling/logit_transformers.hpp | 5 + src/cpp/src/sampling/sampler.cpp | 13 +++ src/cpp/src/sampling/sampler.hpp | 1 + .../structured_output_controller.hpp | 1 + .../structured_output/xgrammar_backend.cpp | 16 +++- .../structured_output/xgrammar_backend.hpp | 10 +- .../openvino_genai/py_openvino_genai.pyi | 10 ++ src/python/py_perf_metrics.cpp | 8 ++ tests/python_tests/test_llm_pipeline.py | 95 +++++++++++++++++++ 13 files changed, 211 insertions(+), 2 deletions(-) diff --git a/src/cpp/include/openvino/genai/perf_metrics.hpp b/src/cpp/include/openvino/genai/perf_metrics.hpp index 1ea0fb55f9..9ba61dcc9f 100644 --- a/src/cpp/include/openvino/genai/perf_metrics.hpp +++ b/src/cpp/include/openvino/genai/perf_metrics.hpp @@ -39,6 +39,9 @@ struct OPENVINO_GENAI_EXPORTS RawPerfMetrics { std::vector m_batch_sizes; std::vector m_durations; std::vector m_inference_durations; + + std::vector m_grammar_init_time; + std::vector m_grammar_compile_time; }; /** @@ -103,6 +106,9 @@ struct OPENVINO_GENAI_EXPORTS PerfMetrics { MeanStdPair ipot; // Inference time (in ms) per output token. MeanStdPair throughput; // Tokens per second. + MeanStdPair grammar_compiler_init_time; // Time to initialize grammar compiler in ms. + MeanStdPair grammar_compile_time; // Time to compile grammar in ms. + MeanStdPair generate_duration; MeanStdPair inference_duration; MeanStdPair tokenization_duration = {-1.0f, -1.0f}; @@ -118,6 +124,9 @@ struct OPENVINO_GENAI_EXPORTS PerfMetrics { MeanStdPair get_tpot(); // Time (in ms) per output token (TPOT). MeanStdPair get_ipot(); // Inference time (in ms) per output token. MeanStdPair get_throughput(); // Tokens per second. + + MeanStdPair get_grammar_compiler_init_time(); + MeanStdPair get_grammar_compile_time(); MeanStdPair get_inference_duration(); // in ms MeanStdPair get_generate_duration(); // in ms diff --git a/src/cpp/src/continuous_batching/pipeline_impl.cpp b/src/cpp/src/continuous_batching/pipeline_impl.cpp index 9f585d1a5d..d95d107053 100644 --- a/src/cpp/src/continuous_batching/pipeline_impl.cpp +++ b/src/cpp/src/continuous_batching/pipeline_impl.cpp @@ -422,10 +422,30 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vectorstart(); + // TODO: It's done because after all the steps are finished logit_transformers are destroyed + // Find a more clever solution. + bool first_step = true; + while (has_non_finished_requests()) { try { const auto infer_start = std::chrono::steady_clock::now(); step(); + if (first_step) { + auto times = m_sampler->get_structured_output_times(); + // each sequence id will be added as a separate compilation time. + // but compiler initialization will be done only once. + for (const auto& t: times) { + if (raw_perf_counters.m_grammar_init_time.empty()) { + raw_perf_counters.m_grammar_init_time.emplace_back(t.second.first); + } else { + OPENVINO_ASSERT(raw_perf_counters.m_grammar_init_time.back().count() == t.second.first, + "Grammar initialization time should be the same for all sequences in the first step"); + } + raw_perf_counters.m_grammar_compile_time.emplace_back(t.second.second); + } + first_step = false; + } + // During prefill step (or steps if max_batch_size < prompt_len) we don't generate new tokens, // but still inference took place, so we need to add this time to the total inference duration. raw_perf_counters.m_inference_durations[0] += MicroSeconds(m_pipeline_metrics.inference_duration); diff --git a/src/cpp/src/perf_metrics.cpp b/src/cpp/src/perf_metrics.cpp index 74f281a942..6b92449d01 100644 --- a/src/cpp/src/perf_metrics.cpp +++ b/src/cpp/src/perf_metrics.cpp @@ -84,6 +84,17 @@ MeanStdPair PerfMetrics::get_inference_duration() { return inference_duration; } +MeanStdPair PerfMetrics::get_grammar_compiler_init_time() { + evaluate_statistics(); + return grammar_compiler_init_time; +} + +MeanStdPair PerfMetrics::get_grammar_compile_time() { + evaluate_statistics(); + return grammar_compile_time; +} + + float PerfMetrics::get_microsec(std::chrono::steady_clock::duration duration) { return std::chrono::duration_cast(duration).count(); } @@ -120,6 +131,9 @@ void PerfMetrics::evaluate_statistics(std::optional start_time) { ipot = calc_mean_and_std(raw_metrics.m_token_infer_durations); ttft = calc_mean_and_std(raw_metrics.m_times_to_first_token); + grammar_compiler_init_time = calc_mean_and_std(raw_metrics.m_grammar_init_time); + grammar_compile_time = calc_mean_and_std(raw_metrics.m_grammar_compile_time); + generate_duration = calc_mean_and_std(raw_metrics.generate_durations); tokenization_duration = calc_mean_and_std(raw_metrics.tokenization_durations); detokenization_duration = calc_mean_and_std(raw_metrics.detokenization_durations); diff --git a/src/cpp/src/sampling/logit_processor.hpp b/src/cpp/src/sampling/logit_processor.hpp index e1af044e2d..ba97c5301d 100644 --- a/src/cpp/src/sampling/logit_processor.hpp +++ b/src/cpp/src/sampling/logit_processor.hpp @@ -122,6 +122,17 @@ class LogitProcessor { m_unique_generated_token_ids->at(token_id)--; } + std::pair get_structured_output_times() const { + for (const auto& transformer: m_logit_transformers) { + auto transformer_with_perf = std::dynamic_pointer_cast(transformer); + if (transformer_with_perf) { + return transformer_with_perf->get_structured_output_times(); + } + // TODO: handle multiple transformers + } + return {0.f, 0.f}; + } + }; } // namespace ov::genai diff --git a/src/cpp/src/sampling/logit_transformers.hpp b/src/cpp/src/sampling/logit_transformers.hpp index 253fd507e5..437db192a2 100644 --- a/src/cpp/src/sampling/logit_transformers.hpp +++ b/src/cpp/src/sampling/logit_transformers.hpp @@ -68,6 +68,11 @@ class IStatefulLogitTransformer: public ILogitTransformer { virtual void accept_tokens(const TokenIds& input_ids) = 0; }; +class LogitTransformersMetrics { +public: + virtual std::pair get_structured_output_times() const = 0; +}; + class TopPFilter : public ILogitTransformer { public: diff --git a/src/cpp/src/sampling/sampler.cpp b/src/cpp/src/sampling/sampler.cpp index 4801c6e67d..5d365fcb0b 100644 --- a/src/cpp/src/sampling/sampler.cpp +++ b/src/cpp/src/sampling/sampler.cpp @@ -241,6 +241,19 @@ std::map Sampler::GroupBeamSearcher::get_beam_idxs() { return next_beams; } +std::map> Sampler::get_structured_output_times() { + // {init_grammar_time, grammar_compilation_time} + std::pair times{0.0f, 0.0f}; + + // { request_id, times } + std::map> structured_output_times; + + for (const auto& [id, logit_processor] : m_logit_processors) { + structured_output_times[id] = logit_processor.get_structured_output_times(); + } + return structured_output_times; +} + void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output, const std::pair>& stop_strings) { diff --git a/src/cpp/src/sampling/sampler.hpp b/src/cpp/src/sampling/sampler.hpp index 2fb8648aac..c1d80fcfcf 100644 --- a/src/cpp/src/sampling/sampler.hpp +++ b/src/cpp/src/sampling/sampler.hpp @@ -123,6 +123,7 @@ class Sampler { void create_logit_processor(uint64_t request_id, const GenerationConfig& sampling_parameters, const TokenIds& prompt); std::map get_beam_idxs(SequenceGroup::CPtr sequence_group); + std::map> get_structured_output_times(); }; class Sampler::GroupBeamSearcher { diff --git a/src/cpp/src/sampling/structured_output/structured_output_controller.hpp b/src/cpp/src/sampling/structured_output/structured_output_controller.hpp index da53361a10..04e24fce79 100644 --- a/src/cpp/src/sampling/structured_output/structured_output_controller.hpp +++ b/src/cpp/src/sampling/structured_output/structured_output_controller.hpp @@ -80,6 +80,7 @@ class StructuredOutputController { std::unordered_map> m_impls; const ov::genai::Tokenizer& m_tokenizer; std::optional m_vocab_size; + std::mutex m_mutex; }; } // namespace genai diff --git a/src/cpp/src/sampling/structured_output/xgrammar_backend.cpp b/src/cpp/src/sampling/structured_output/xgrammar_backend.cpp index 7a7ef3be96..8a693eec3f 100644 --- a/src/cpp/src/sampling/structured_output/xgrammar_backend.cpp +++ b/src/cpp/src/sampling/structured_output/xgrammar_backend.cpp @@ -8,6 +8,7 @@ namespace ov { namespace genai { XGrammarStructuredOutput::XGrammarStructuredOutput(const Tokenizer& tokenizer, std::optional vocab_size) { + const auto start = std::chrono::steady_clock::now(); auto vocab_vector = tokenizer.get_vocab_vector(); if (!vocab_size.has_value()) { vocab_size = vocab_vector.size(); @@ -20,7 +21,10 @@ XGrammarStructuredOutput::XGrammarStructuredOutput(const Tokenizer& tokenizer, s std::vector{static_cast(tokenizer.get_eos_token_id())}, true ); + m_grammar_compiler = std::make_unique(std::move(tokenizer_info)); + const auto end = std::chrono::steady_clock::now(); + m_init_grammar_compiler_time = std::chrono::duration_cast(end - start).count(); } std::shared_ptr @@ -31,6 +35,8 @@ XGrammarStructuredOutput::get_logits_transformer(const GenerationConfig& samplin auto& guided_gen_config = *sampling_parameters.structured_output_config; guided_gen_config.validate(); + auto start = std::chrono::steady_clock::now(); + xgrammar::Grammar grammar; if (guided_gen_config.json_schema.has_value()) { grammar = xgrammar::Grammar::FromJSONSchema(*guided_gen_config.json_schema); @@ -41,9 +47,17 @@ XGrammarStructuredOutput::get_logits_transformer(const GenerationConfig& samplin } auto compiled_grammar = m_grammar_compiler->CompileGrammar(grammar); + std::vector override_stop_tokens(sampling_parameters.stop_token_ids.begin(), sampling_parameters.stop_token_ids.end()); + auto logit_transformer = std::make_shared(std::move(compiled_grammar), override_stop_tokens); - return std::make_shared(std::move(compiled_grammar), override_stop_tokens); + auto stop = std::chrono::steady_clock::now(); + + // In order to keep all the metrics in one place we store grammaer compiler initilization time + // in the logit_transofmer as well, even though grammar compiler is initialzed prior to + logit_transformer->m_init_grammar_compiler_time = m_init_grammar_compiler_time; + logit_transformer->m_grammar_compile_time = std::chrono::duration_cast(stop - start).count(); + return logit_transformer; } namespace LogitTransformers { diff --git a/src/cpp/src/sampling/structured_output/xgrammar_backend.hpp b/src/cpp/src/sampling/structured_output/xgrammar_backend.hpp index e326c6880c..cb1b544af1 100644 --- a/src/cpp/src/sampling/structured_output/xgrammar_backend.hpp +++ b/src/cpp/src/sampling/structured_output/xgrammar_backend.hpp @@ -24,7 +24,7 @@ namespace LogitTransformers { * which applies grammar constraints to the logits each time a new token is generated, and * accepts tokens to update the internal state of the grammar matcher. */ -class XGrammarLogitsTransformer : public IStatefulLogitTransformer { +class XGrammarLogitsTransformer : public IStatefulLogitTransformer, public LogitTransformersMetrics { public: XGrammarLogitsTransformer( const xgrammar::CompiledGrammar& compiled_grammar, @@ -36,6 +36,13 @@ class XGrammarLogitsTransformer : public IStatefulLogitTransformer { void accept_tokens(const TokenIds& input_ids) override; void apply(Logits& logits) override; + + std::pair get_structured_output_times() const override { + return {m_init_grammar_compiler_time, m_grammar_compile_time}; + }; + + float m_init_grammar_compiler_time; + float m_grammar_compile_time; protected: xgrammar::GrammarMatcher m_grammar_matcher; @@ -80,6 +87,7 @@ class XGrammarStructuredOutput : public IStructuredOutputImpl { std::shared_ptr get_logits_transformer(const GenerationConfig& sampling_parameters) override; private: std::unique_ptr m_grammar_compiler; + float m_init_grammar_compiler_time; }; diff --git a/src/python/openvino_genai/py_openvino_genai.pyi b/src/python/openvino_genai/py_openvino_genai.pyi index 0faac1a970..3547592045 100644 --- a/src/python/openvino_genai/py_openvino_genai.pyi +++ b/src/python/openvino_genai/py_openvino_genai.pyi @@ -1343,6 +1343,10 @@ class PerfMetrics: ... def get_generate_duration(self) -> MeanStdPair: ... + def get_grammar_compile_time(self) -> MeanStdPair: + ... + def get_grammar_compiler_init_time(self) -> MeanStdPair: + ... def get_inference_duration(self) -> MeanStdPair: ... def get_ipot(self) -> MeanStdPair: @@ -1478,6 +1482,12 @@ class RawPerfMetrics: def m_durations(self) -> list[float]: ... @property + def m_grammar_compile_time(self) -> list[float]: + ... + @property + def m_grammar_init_time(self) -> list[float]: + ... + @property def m_new_token_times(self) -> list[float]: ... @property diff --git a/src/python/py_perf_metrics.cpp b/src/python/py_perf_metrics.cpp index 6b2019aa97..3f18de399f 100644 --- a/src/python/py_perf_metrics.cpp +++ b/src/python/py_perf_metrics.cpp @@ -146,6 +146,12 @@ void init_perf_metrics(py::module_& m) { }) .def_property_readonly("inference_durations", [](const RawPerfMetrics &rw) { return pyutils::get_ms(rw, &RawPerfMetrics::m_inference_durations); + }) + .def_property_readonly("m_grammar_init_time", [](const RawPerfMetrics &rw) { + return pyutils::get_ms(rw, &RawPerfMetrics::m_grammar_init_time); + }) + .def_property_readonly("m_grammar_compile_time", [](const RawPerfMetrics &rw) { + return pyutils::get_ms(rw, &RawPerfMetrics::m_grammar_compile_time); }); py::class_(m, "MeanStdPair") @@ -159,6 +165,8 @@ void init_perf_metrics(py::module_& m) { py::class_(m, "PerfMetrics", perf_metrics_docstring) .def(py::init<>()) .def("get_load_time", &PerfMetrics::get_load_time) + .def("get_grammar_compiler_init_time", &PerfMetrics::get_grammar_compiler_init_time) + .def("get_grammar_compile_time", &PerfMetrics::get_grammar_compile_time) .def("get_num_generated_tokens", &PerfMetrics::get_num_generated_tokens) .def("get_num_input_tokens", &PerfMetrics::get_num_input_tokens) .def("get_ttft", &PerfMetrics::get_ttft) diff --git a/tests/python_tests/test_llm_pipeline.py b/tests/python_tests/test_llm_pipeline.py index 23ef545d49..63227377bd 100644 --- a/tests/python_tests/test_llm_pipeline.py +++ b/tests/python_tests/test_llm_pipeline.py @@ -778,6 +778,101 @@ def test_perf_metrics(generation_config, prompt): assert len(raw_metrics.m_batch_sizes) > 0 assert len(raw_metrics.m_durations) > 0 +import json +from typing import Literal +from pydantic import BaseModel, Field + + +class Person(BaseModel): + name: str = Field(pattern=r"^[A-Z][a-z]{1,20}$") + surname: str = Field(pattern=r"^[A-Z][a-z]{1,20}$") + age: int + city: Literal["Dublin", "Dubai", "Munich"] + +test_cases = [ + (dict(max_new_tokens=20, structured_output_config=ov_genai.StructuredOutputConfig(json_schema=json.dumps(Person.model_json_schema()))), + 'Generate json of a person'), +] +@pytest.mark.parametrize("generation_config,prompt", test_cases) +@pytest.mark.precommit +@pytest.mark.nightly +def test_perf_metrics_with_structured_output(generation_config, prompt): + import time + start_time = time.perf_counter() + model_id = 'katuni4ka/tiny-random-gemma2' + perf_metrics = run_perf_metrics_collection(model_id, generation_config, prompt) + total_time = (time.perf_counter() - start_time) * 1000 + + # Check that load time is adequate. + load_time = perf_metrics.get_load_time() + assert load_time > 0 and load_time < total_time + + # Check that num input and generated tokens are adequate. + num_generated_tokens = perf_metrics.get_num_generated_tokens() + assert num_generated_tokens > 0 and num_generated_tokens <= generation_config['max_new_tokens'] + + num_input_tokens = perf_metrics.get_num_input_tokens() + assert num_input_tokens > 0 and num_input_tokens <= len(prompt) + + mean_ttft, std_ttft = perf_metrics.get_ttft() + assert (mean_ttft, std_ttft) == (perf_metrics.get_ttft().mean, perf_metrics.get_ttft().std) + assert mean_ttft > 0 and mean_ttft < 1000.0 + + raw_metrics = perf_metrics.raw_metrics + durations = np.array(raw_metrics.m_durations) / 1000 + # Check that prefill is not included in durations for TPOT calculation. + # For the very long prompt prefill is slow and TTFT is much larger than any other token generation duration. + assert np.all(mean_ttft > durations * 2) + + mean_tpot, std_tpot = perf_metrics.get_tpot() + assert (mean_tpot, std_tpot) == (perf_metrics.get_tpot().mean, perf_metrics.get_tpot().std) + assert mean_tpot > 0 and mean_ttft < 1000.0 + + mean_throughput, std_throughput = perf_metrics.get_throughput() + assert (mean_throughput, std_throughput) == (perf_metrics.get_throughput().mean, perf_metrics.get_throughput().std) + assert mean_throughput > 0 and mean_throughput < 20000.0 + + mean_gen_duration, std_gen_duration = perf_metrics.get_generate_duration() + assert (mean_gen_duration, std_gen_duration) == (perf_metrics.get_generate_duration().mean, perf_metrics.get_generate_duration().std) + assert mean_gen_duration > 0 and load_time + mean_gen_duration < total_time + assert std_gen_duration == 0 + + mean_tok_duration, std_tok_duration = perf_metrics.get_tokenization_duration() + assert (mean_tok_duration, std_tok_duration) == (perf_metrics.get_tokenization_duration().mean, perf_metrics.get_tokenization_duration().std) + assert mean_tok_duration > 0 and mean_tok_duration < mean_gen_duration + assert std_tok_duration == 0 + + mean_detok_duration, std_detok_duration = perf_metrics.get_detokenization_duration() + assert (mean_detok_duration, std_detok_duration) == (perf_metrics.get_detokenization_duration().mean, perf_metrics.get_detokenization_duration().std) + assert mean_detok_duration > 0 and mean_detok_duration < mean_gen_duration + assert std_detok_duration == 0 + + # assert that calculating statistics manually from the raw counters we get the same restults as from PerfMetrics + assert np.allclose(mean_tpot, np.mean(durations)) + assert np.allclose(std_tpot, np.std(durations)) + + raw_dur = np.array(raw_metrics.generate_durations) / 1000 + assert np.allclose(mean_gen_duration, np.mean(raw_dur)) + assert np.allclose(std_gen_duration, np.std(raw_dur)) + + raw_dur = np.array(raw_metrics.tokenization_durations) / 1000 + assert np.allclose(mean_tok_duration, np.mean(raw_dur)) + assert np.allclose(std_tok_duration, np.std(raw_dur)) + + raw_dur = np.array(raw_metrics.detokenization_durations) / 1000 + assert np.allclose(mean_detok_duration, np.mean(raw_dur)) + assert np.allclose(std_detok_duration, np.std(raw_dur)) + + assert len(raw_metrics.m_times_to_first_token) > 0 + assert len(raw_metrics.m_batch_sizes) > 0 + assert len(raw_metrics.m_durations) > 0 + + assert len(raw_metrics.m_grammar_init_time) == 1 + assert len(raw_metrics.m_grammar_compile_time) > 0 + + assert np.allclose(raw_metrics.m_grammar_init_time * 1000, perf_metrics.get_grammar_init_time().mean) + assert np.allclose(raw_metrics.m_grammar_compile_time * 1000, perf_metrics.get_grammar_compile_time().mean) + @pytest.mark.parametrize("pipeline_type", get_main_pipeline_types()) @pytest.mark.parametrize("stop_str", {True, False}) From 8b1e28e64891d02d121051477889f0bbf28cde76 Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Thu, 3 Jul 2025 14:24:57 +0200 Subject: [PATCH 02/10] simplify getting compile times from structured_output_controller --- .../include/openvino/genai/perf_metrics.hpp | 11 +++++++ .../src/continuous_batching/pipeline_impl.cpp | 29 +++++-------------- src/cpp/src/generation_config.cpp | 3 ++ src/cpp/src/sampling/logit_processor.hpp | 11 ------- src/cpp/src/sampling/logit_transformers.hpp | 5 ---- src/cpp/src/sampling/sampler.cpp | 20 +++++++------ src/cpp/src/sampling/sampler.hpp | 3 +- .../structured_output_controller.cpp | 9 +++++- .../structured_output_controller.hpp | 5 ++-- .../structured_output/xgrammar_backend.cpp | 16 +--------- .../structured_output/xgrammar_backend.hpp | 10 +------ src/python/py_perf_metrics.cpp | 9 ++++++ src/python/py_utils.cpp | 2 ++ tests/python_tests/test_llm_pipeline.py | 8 +++-- 14 files changed, 64 insertions(+), 77 deletions(-) diff --git a/src/cpp/include/openvino/genai/perf_metrics.hpp b/src/cpp/include/openvino/genai/perf_metrics.hpp index 9ba61dcc9f..5097aa044f 100644 --- a/src/cpp/include/openvino/genai/perf_metrics.hpp +++ b/src/cpp/include/openvino/genai/perf_metrics.hpp @@ -52,6 +52,17 @@ struct OPENVINO_GENAI_EXPORTS MeanStdPair { float std; }; +/** +* @brief Structure to store list of durations in milliseconds. +*/ +struct OPENVINO_GENAI_EXPORTS DurationValues { + float min; + float max; + float mean; + float std; + std::vector values; +}; + /** * @brief Holds performance metrics for each generate call. * diff --git a/src/cpp/src/continuous_batching/pipeline_impl.cpp b/src/cpp/src/continuous_batching/pipeline_impl.cpp index d95d107053..7c86851dc2 100644 --- a/src/cpp/src/continuous_batching/pipeline_impl.cpp +++ b/src/cpp/src/continuous_batching/pipeline_impl.cpp @@ -421,31 +421,12 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vectorstart(); - - // TODO: It's done because after all the steps are finished logit_transformers are destroyed - // Find a more clever solution. - bool first_step = true; - + m_sampler->clear_structured_output_compile_times(); while (has_non_finished_requests()) { try { const auto infer_start = std::chrono::steady_clock::now(); step(); - if (first_step) { - auto times = m_sampler->get_structured_output_times(); - // each sequence id will be added as a separate compilation time. - // but compiler initialization will be done only once. - for (const auto& t: times) { - if (raw_perf_counters.m_grammar_init_time.empty()) { - raw_perf_counters.m_grammar_init_time.emplace_back(t.second.first); - } else { - OPENVINO_ASSERT(raw_perf_counters.m_grammar_init_time.back().count() == t.second.first, - "Grammar initialization time should be the same for all sequences in the first step"); - } - raw_perf_counters.m_grammar_compile_time.emplace_back(t.second.second); - } - first_step = false; - } - + // During prefill step (or steps if max_batch_size < prompt_len) we don't generate new tokens, // but still inference took place, so we need to add this time to the total inference duration. raw_perf_counters.m_inference_durations[0] += MicroSeconds(m_pipeline_metrics.inference_duration); @@ -464,6 +445,12 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vectorget_structured_output_times(); + raw_perf_counters.m_grammar_init_time.emplace_back(times.first); + for (const auto& t: times.second) { + raw_perf_counters.m_grammar_compile_time.emplace_back(t); + } + // waiting for competion of streaming streamer_ptr->end(); diff --git a/src/cpp/src/generation_config.cpp b/src/cpp/src/generation_config.cpp index 558804030a..b8c72b8790 100644 --- a/src/cpp/src/generation_config.cpp +++ b/src/cpp/src/generation_config.cpp @@ -149,6 +149,9 @@ void GenerationConfig::update_generation_config(const ov::AnyMap& properties) { read_anymap_param(properties, "assistant_confidence_threshold", assistant_confidence_threshold); read_anymap_param(properties, "num_assistant_tokens", num_assistant_tokens); read_anymap_param(properties, "max_ngram_size", max_ngram_size); + + // Structured Output generation + read_anymap_param(properties, "structured_output_config", structured_output_config); } StructuredOutputConfig::StructuredOutputConfig(const ov::AnyMap& properties) { diff --git a/src/cpp/src/sampling/logit_processor.hpp b/src/cpp/src/sampling/logit_processor.hpp index ba97c5301d..e1af044e2d 100644 --- a/src/cpp/src/sampling/logit_processor.hpp +++ b/src/cpp/src/sampling/logit_processor.hpp @@ -122,17 +122,6 @@ class LogitProcessor { m_unique_generated_token_ids->at(token_id)--; } - std::pair get_structured_output_times() const { - for (const auto& transformer: m_logit_transformers) { - auto transformer_with_perf = std::dynamic_pointer_cast(transformer); - if (transformer_with_perf) { - return transformer_with_perf->get_structured_output_times(); - } - // TODO: handle multiple transformers - } - return {0.f, 0.f}; - } - }; } // namespace ov::genai diff --git a/src/cpp/src/sampling/logit_transformers.hpp b/src/cpp/src/sampling/logit_transformers.hpp index 437db192a2..253fd507e5 100644 --- a/src/cpp/src/sampling/logit_transformers.hpp +++ b/src/cpp/src/sampling/logit_transformers.hpp @@ -68,11 +68,6 @@ class IStatefulLogitTransformer: public ILogitTransformer { virtual void accept_tokens(const TokenIds& input_ids) = 0; }; -class LogitTransformersMetrics { -public: - virtual std::pair get_structured_output_times() const = 0; -}; - class TopPFilter : public ILogitTransformer { public: diff --git a/src/cpp/src/sampling/sampler.cpp b/src/cpp/src/sampling/sampler.cpp index 5d365fcb0b..e54681b239 100644 --- a/src/cpp/src/sampling/sampler.cpp +++ b/src/cpp/src/sampling/sampler.cpp @@ -241,17 +241,19 @@ std::map Sampler::GroupBeamSearcher::get_beam_idxs() { return next_beams; } -std::map> Sampler::get_structured_output_times() { - // {init_grammar_time, grammar_compilation_time} - std::pair times{0.0f, 0.0f}; - - // { request_id, times } - std::map> structured_output_times; +std::pair> Sampler::get_structured_output_times() { + if (m_structured_output_controller) { + return {m_structured_output_controller->m_init_grammar_compiler_time, m_structured_output_controller->m_grammar_compile_times}; + } else { + // If compiled without structured output support, return empty times + return {0.0f, {}}; + } +} - for (const auto& [id, logit_processor] : m_logit_processors) { - structured_output_times[id] = logit_processor.get_structured_output_times(); +void Sampler::clear_structured_output_compile_times() { + if (m_structured_output_controller) { + m_structured_output_controller->m_grammar_compile_times.clear(); } - return structured_output_times; } void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, diff --git a/src/cpp/src/sampling/sampler.hpp b/src/cpp/src/sampling/sampler.hpp index c1d80fcfcf..1e64040b1d 100644 --- a/src/cpp/src/sampling/sampler.hpp +++ b/src/cpp/src/sampling/sampler.hpp @@ -123,7 +123,8 @@ class Sampler { void create_logit_processor(uint64_t request_id, const GenerationConfig& sampling_parameters, const TokenIds& prompt); std::map get_beam_idxs(SequenceGroup::CPtr sequence_group); - std::map> get_structured_output_times(); + std::pair> get_structured_output_times(); + void clear_structured_output_compile_times(); }; class Sampler::GroupBeamSearcher { diff --git a/src/cpp/src/sampling/structured_output/structured_output_controller.cpp b/src/cpp/src/sampling/structured_output/structured_output_controller.cpp index cdce0e054a..f440e63fbf 100644 --- a/src/cpp/src/sampling/structured_output/structured_output_controller.cpp +++ b/src/cpp/src/sampling/structured_output/structured_output_controller.cpp @@ -52,12 +52,19 @@ StructuredOutputController::get_logits_transformer(const ov::genai::GenerationCo } // Create the backend instance and store it + const auto start = std::chrono::steady_clock::now(); m_impls[backend_name] = factory_it->second(m_tokenizer, m_vocab_size); impl_it = m_impls.find(backend_name); + const auto end = std::chrono::steady_clock::now(); + m_init_grammar_compiler_time = std::chrono::duration_cast(end - start).count(); } // Use the instantiated backend - return impl_it->second->get_logits_transformer(sampling_parameters); + const auto start = std::chrono::steady_clock::now(); + auto res = impl_it->second->get_logits_transformer(sampling_parameters); + const auto end = std::chrono::steady_clock::now(); + m_grammar_compile_times.emplace_back(std::chrono::duration_cast(end - start).count()); + return res; } } // namespace genai diff --git a/src/cpp/src/sampling/structured_output/structured_output_controller.hpp b/src/cpp/src/sampling/structured_output/structured_output_controller.hpp index 04e24fce79..1ca1fcb45e 100644 --- a/src/cpp/src/sampling/structured_output/structured_output_controller.hpp +++ b/src/cpp/src/sampling/structured_output/structured_output_controller.hpp @@ -75,12 +75,13 @@ class StructuredOutputController { static void set_default_backend(const std::string& name); static std::string& get_default_backend_name(); static std::unordered_map& get_backend_registry(); - + + float m_init_grammar_compiler_time; + std::vector m_grammar_compile_times; private: std::unordered_map> m_impls; const ov::genai::Tokenizer& m_tokenizer; std::optional m_vocab_size; - std::mutex m_mutex; }; } // namespace genai diff --git a/src/cpp/src/sampling/structured_output/xgrammar_backend.cpp b/src/cpp/src/sampling/structured_output/xgrammar_backend.cpp index 8a693eec3f..7a7ef3be96 100644 --- a/src/cpp/src/sampling/structured_output/xgrammar_backend.cpp +++ b/src/cpp/src/sampling/structured_output/xgrammar_backend.cpp @@ -8,7 +8,6 @@ namespace ov { namespace genai { XGrammarStructuredOutput::XGrammarStructuredOutput(const Tokenizer& tokenizer, std::optional vocab_size) { - const auto start = std::chrono::steady_clock::now(); auto vocab_vector = tokenizer.get_vocab_vector(); if (!vocab_size.has_value()) { vocab_size = vocab_vector.size(); @@ -21,10 +20,7 @@ XGrammarStructuredOutput::XGrammarStructuredOutput(const Tokenizer& tokenizer, s std::vector{static_cast(tokenizer.get_eos_token_id())}, true ); - m_grammar_compiler = std::make_unique(std::move(tokenizer_info)); - const auto end = std::chrono::steady_clock::now(); - m_init_grammar_compiler_time = std::chrono::duration_cast(end - start).count(); } std::shared_ptr @@ -35,8 +31,6 @@ XGrammarStructuredOutput::get_logits_transformer(const GenerationConfig& samplin auto& guided_gen_config = *sampling_parameters.structured_output_config; guided_gen_config.validate(); - auto start = std::chrono::steady_clock::now(); - xgrammar::Grammar grammar; if (guided_gen_config.json_schema.has_value()) { grammar = xgrammar::Grammar::FromJSONSchema(*guided_gen_config.json_schema); @@ -47,17 +41,9 @@ XGrammarStructuredOutput::get_logits_transformer(const GenerationConfig& samplin } auto compiled_grammar = m_grammar_compiler->CompileGrammar(grammar); - std::vector override_stop_tokens(sampling_parameters.stop_token_ids.begin(), sampling_parameters.stop_token_ids.end()); - auto logit_transformer = std::make_shared(std::move(compiled_grammar), override_stop_tokens); - auto stop = std::chrono::steady_clock::now(); - - // In order to keep all the metrics in one place we store grammaer compiler initilization time - // in the logit_transofmer as well, even though grammar compiler is initialzed prior to - logit_transformer->m_init_grammar_compiler_time = m_init_grammar_compiler_time; - logit_transformer->m_grammar_compile_time = std::chrono::duration_cast(stop - start).count(); - return logit_transformer; + return std::make_shared(std::move(compiled_grammar), override_stop_tokens); } namespace LogitTransformers { diff --git a/src/cpp/src/sampling/structured_output/xgrammar_backend.hpp b/src/cpp/src/sampling/structured_output/xgrammar_backend.hpp index cb1b544af1..e326c6880c 100644 --- a/src/cpp/src/sampling/structured_output/xgrammar_backend.hpp +++ b/src/cpp/src/sampling/structured_output/xgrammar_backend.hpp @@ -24,7 +24,7 @@ namespace LogitTransformers { * which applies grammar constraints to the logits each time a new token is generated, and * accepts tokens to update the internal state of the grammar matcher. */ -class XGrammarLogitsTransformer : public IStatefulLogitTransformer, public LogitTransformersMetrics { +class XGrammarLogitsTransformer : public IStatefulLogitTransformer { public: XGrammarLogitsTransformer( const xgrammar::CompiledGrammar& compiled_grammar, @@ -36,13 +36,6 @@ class XGrammarLogitsTransformer : public IStatefulLogitTransformer, public Logit void accept_tokens(const TokenIds& input_ids) override; void apply(Logits& logits) override; - - std::pair get_structured_output_times() const override { - return {m_init_grammar_compiler_time, m_grammar_compile_time}; - }; - - float m_init_grammar_compiler_time; - float m_grammar_compile_time; protected: xgrammar::GrammarMatcher m_grammar_matcher; @@ -87,7 +80,6 @@ class XGrammarStructuredOutput : public IStructuredOutputImpl { std::shared_ptr get_logits_transformer(const GenerationConfig& sampling_parameters) override; private: std::unique_ptr m_grammar_compiler; - float m_init_grammar_compiler_time; }; diff --git a/src/python/py_perf_metrics.cpp b/src/python/py_perf_metrics.cpp index 3f18de399f..d35440780b 100644 --- a/src/python/py_perf_metrics.cpp +++ b/src/python/py_perf_metrics.cpp @@ -12,6 +12,7 @@ namespace py = pybind11; +using ov::genai::DurationValues; using ov::genai::MeanStdPair; using ov::genai::PerfMetrics; using ov::genai::RawPerfMetrics; @@ -154,6 +155,14 @@ void init_perf_metrics(py::module_& m) { return pyutils::get_ms(rw, &RawPerfMetrics::m_grammar_compile_time); }); + py::class_(m, "DurationValues") + .def(py::init<>()) + .def_readonly("mean", &DurationValues::mean) + .def_readonly("std", &DurationValues::std); + // .def("__iter__", [](const MeanStdPair &self) { + // return py::make_iterator(&self.mean, &self.std + 1); + // }, py::keep_alive<0, 1>()); // Keep object alive while the iterator is used; + py::class_(m, "MeanStdPair") .def(py::init<>()) .def_readonly("mean", &MeanStdPair::mean) diff --git a/src/python/py_utils.cpp b/src/python/py_utils.cpp index 09ea05f202..341648511a 100644 --- a/src/python/py_utils.cpp +++ b/src/python/py_utils.cpp @@ -334,6 +334,8 @@ ov::AnyMap kwargs_to_any_map(const py::kwargs& kwargs) { if (utils::py_object_is_any_map(value) && key == "config") { auto map = utils::py_object_to_any_map(value); params.insert(map.begin(), map.end()); + } else if (py::isinstance(value)) { + params[key] = py::cast(value); } else { if (py::isinstance(value)) { OPENVINO_ASSERT(!py::isinstance(value), "Property \"", key, "\" can't be None."); diff --git a/tests/python_tests/test_llm_pipeline.py b/tests/python_tests/test_llm_pipeline.py index 63227377bd..6ed2cb93fb 100644 --- a/tests/python_tests/test_llm_pipeline.py +++ b/tests/python_tests/test_llm_pipeline.py @@ -870,9 +870,11 @@ def test_perf_metrics_with_structured_output(generation_config, prompt): assert len(raw_metrics.m_grammar_init_time) == 1 assert len(raw_metrics.m_grammar_compile_time) > 0 - assert np.allclose(raw_metrics.m_grammar_init_time * 1000, perf_metrics.get_grammar_init_time().mean) - assert np.allclose(raw_metrics.m_grammar_compile_time * 1000, perf_metrics.get_grammar_compile_time().mean) - + raw_init_times = np.array(raw_metrics.m_grammar_init_time) / 1000 + raw_compile_times = np.array(raw_metrics.m_grammar_compile_time) / 1000 + assert np.allclose(np.mean(raw_init_times), perf_metrics.get_grammar_compiler_init_time().mean) + assert np.allclose(np.mean(raw_compile_times), perf_metrics.get_grammar_compile_time().mean) + # TODO: get max, min values as well @pytest.mark.parametrize("pipeline_type", get_main_pipeline_types()) @pytest.mark.parametrize("stop_str", {True, False}) From 5199c22d625c6b140097a07f7d412c129dcf4932 Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Thu, 3 Jul 2025 15:14:48 +0200 Subject: [PATCH 03/10] calculate min, max as well --- .../include/openvino/genai/perf_metrics.hpp | 15 ++++++------ src/cpp/src/perf_metrics.cpp | 17 ++++++++++---- .../openvino_genai/py_openvino_genai.pyi | 23 ++++++++++++++++--- src/python/py_perf_metrics.cpp | 16 +++++++------ tests/python_tests/test_llm_pipeline.py | 3 ++- 5 files changed, 51 insertions(+), 23 deletions(-) diff --git a/src/cpp/include/openvino/genai/perf_metrics.hpp b/src/cpp/include/openvino/genai/perf_metrics.hpp index 5097aa044f..2c1194164c 100644 --- a/src/cpp/include/openvino/genai/perf_metrics.hpp +++ b/src/cpp/include/openvino/genai/perf_metrics.hpp @@ -55,12 +55,11 @@ struct OPENVINO_GENAI_EXPORTS MeanStdPair { /** * @brief Structure to store list of durations in milliseconds. */ -struct OPENVINO_GENAI_EXPORTS DurationValues { - float min; - float max; +struct OPENVINO_GENAI_EXPORTS SummaryStats { float mean; float std; - std::vector values; + float min; + float max; }; /** @@ -117,8 +116,8 @@ struct OPENVINO_GENAI_EXPORTS PerfMetrics { MeanStdPair ipot; // Inference time (in ms) per output token. MeanStdPair throughput; // Tokens per second. - MeanStdPair grammar_compiler_init_time; // Time to initialize grammar compiler in ms. - MeanStdPair grammar_compile_time; // Time to compile grammar in ms. + SummaryStats grammar_compiler_init_time; // Time to initialize grammar compiler in ms. + SummaryStats grammar_compile_time; // Time to compile grammar in ms. MeanStdPair generate_duration; MeanStdPair inference_duration; @@ -136,8 +135,8 @@ struct OPENVINO_GENAI_EXPORTS PerfMetrics { MeanStdPair get_ipot(); // Inference time (in ms) per output token. MeanStdPair get_throughput(); // Tokens per second. - MeanStdPair get_grammar_compiler_init_time(); - MeanStdPair get_grammar_compile_time(); + SummaryStats get_grammar_compiler_init_time(); + SummaryStats get_grammar_compile_time(); MeanStdPair get_inference_duration(); // in ms MeanStdPair get_generate_duration(); // in ms diff --git a/src/cpp/src/perf_metrics.cpp b/src/cpp/src/perf_metrics.cpp index 6b92449d01..53cae394ba 100644 --- a/src/cpp/src/perf_metrics.cpp +++ b/src/cpp/src/perf_metrics.cpp @@ -30,6 +30,15 @@ ov::genai::MeanStdPair calc_mean_and_std(const std::vector& durations) { + if (durations.size() == 0) { + return {-1, -1, -1, -1}; + } + auto minmax = std::minmax_element(durations.begin(), durations.end()); + auto meanstd = calc_mean_and_std(durations); + return {minmax.first->count() / 1000.0f, minmax.second->count() / 1000.0f, meanstd.mean, meanstd.std}; +} + float PerfMetrics::get_load_time() { return load_time; } @@ -84,12 +93,12 @@ MeanStdPair PerfMetrics::get_inference_duration() { return inference_duration; } -MeanStdPair PerfMetrics::get_grammar_compiler_init_time() { +SummaryStats PerfMetrics::get_grammar_compiler_init_time() { evaluate_statistics(); return grammar_compiler_init_time; } -MeanStdPair PerfMetrics::get_grammar_compile_time() { +SummaryStats PerfMetrics::get_grammar_compile_time() { evaluate_statistics(); return grammar_compile_time; } @@ -131,8 +140,8 @@ void PerfMetrics::evaluate_statistics(std::optional start_time) { ipot = calc_mean_and_std(raw_metrics.m_token_infer_durations); ttft = calc_mean_and_std(raw_metrics.m_times_to_first_token); - grammar_compiler_init_time = calc_mean_and_std(raw_metrics.m_grammar_init_time); - grammar_compile_time = calc_mean_and_std(raw_metrics.m_grammar_compile_time); + grammar_compile_time = calc_full_stat(raw_metrics.m_grammar_compile_time); + grammar_compiler_init_time = calc_full_stat(raw_metrics.m_grammar_init_time); generate_duration = calc_mean_and_std(raw_metrics.generate_durations); tokenization_duration = calc_mean_and_std(raw_metrics.tokenization_durations); diff --git a/src/python/openvino_genai/py_openvino_genai.pyi b/src/python/openvino_genai/py_openvino_genai.pyi index 3547592045..f4a41b4045 100644 --- a/src/python/openvino_genai/py_openvino_genai.pyi +++ b/src/python/openvino_genai/py_openvino_genai.pyi @@ -5,7 +5,7 @@ from __future__ import annotations import openvino._pyopenvino import os import typing -__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedGenerationResult', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationHandle', 'GenerationOutput', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'ImageGenerationPerfMetrics', 'InpaintingPipeline', 'LLMPipeline', 'MeanStdPair', 'PerfMetrics', 'PipelineMetrics', 'RawImageGenerationPerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'SpeechGenerationConfig', 'SpeechGenerationPerfMetrics', 'StopCriteria', 'StreamerBase', 'StreamingStatus', 'StructuredOutputConfig', 'T5EncoderModel', 'Text2ImagePipeline', 'Text2SpeechDecodedResults', 'Text2SpeechPipeline', 'TextEmbeddingPipeline', 'TextStreamer', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMDecodedResults', 'VLMPerfMetrics', 'VLMPipeline', 'VLMRawPerfMetrics', 'WhisperDecodedResultChunk', 'WhisperDecodedResults', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version'] +__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedGenerationResult', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationHandle', 'GenerationOutput', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'ImageGenerationPerfMetrics', 'InpaintingPipeline', 'LLMPipeline', 'MeanStdPair', 'PerfMetrics', 'PipelineMetrics', 'RawImageGenerationPerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'SpeechGenerationConfig', 'SpeechGenerationPerfMetrics', 'StopCriteria', 'StreamerBase', 'StreamingStatus', 'StructuredOutputConfig', 'SummaryStats', 'T5EncoderModel', 'Text2ImagePipeline', 'Text2SpeechDecodedResults', 'Text2SpeechPipeline', 'TextEmbeddingPipeline', 'TextStreamer', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMDecodedResults', 'VLMPerfMetrics', 'VLMPipeline', 'VLMRawPerfMetrics', 'WhisperDecodedResultChunk', 'WhisperDecodedResults', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version'] class Adapter: """ Immutable LoRA Adapter that carries the adaptation matrices and serves as unique adapter identifier. @@ -1343,9 +1343,9 @@ class PerfMetrics: ... def get_generate_duration(self) -> MeanStdPair: ... - def get_grammar_compile_time(self) -> MeanStdPair: + def get_grammar_compile_time(self) -> SummaryStats: ... - def get_grammar_compiler_init_time(self) -> MeanStdPair: + def get_grammar_compiler_init_time(self) -> SummaryStats: ... def get_inference_duration(self) -> MeanStdPair: ... @@ -1847,6 +1847,23 @@ class StructuredOutputConfig: @regex.setter def regex(self, arg0: str | None) -> None: ... +class SummaryStats: + def __init__(self) -> None: + ... + def as_tuple(self) -> tuple: + ... + @property + def max(self) -> float: + ... + @property + def mean(self) -> float: + ... + @property + def min(self) -> float: + ... + @property + def std(self) -> float: + ... class T5EncoderModel: """ T5EncoderModel class. diff --git a/src/python/py_perf_metrics.cpp b/src/python/py_perf_metrics.cpp index d35440780b..b55868a3dd 100644 --- a/src/python/py_perf_metrics.cpp +++ b/src/python/py_perf_metrics.cpp @@ -12,7 +12,7 @@ namespace py = pybind11; -using ov::genai::DurationValues; +using ov::genai::SummaryStats; using ov::genai::MeanStdPair; using ov::genai::PerfMetrics; using ov::genai::RawPerfMetrics; @@ -155,13 +155,15 @@ void init_perf_metrics(py::module_& m) { return pyutils::get_ms(rw, &RawPerfMetrics::m_grammar_compile_time); }); - py::class_(m, "DurationValues") + py::class_(m, "SummaryStats") .def(py::init<>()) - .def_readonly("mean", &DurationValues::mean) - .def_readonly("std", &DurationValues::std); - // .def("__iter__", [](const MeanStdPair &self) { - // return py::make_iterator(&self.mean, &self.std + 1); - // }, py::keep_alive<0, 1>()); // Keep object alive while the iterator is used; + .def_readonly("mean", &SummaryStats::mean) + .def_readonly("std", &SummaryStats::std) + .def_readonly("min", &SummaryStats::min) + .def_readonly("max", &SummaryStats::max) + .def("as_tuple", [](const SummaryStats& self) { + return py::make_tuple(self.mean, self.std, self.min, self.max); + }); py::class_(m, "MeanStdPair") .def(py::init<>()) diff --git a/tests/python_tests/test_llm_pipeline.py b/tests/python_tests/test_llm_pipeline.py index 6ed2cb93fb..2f9e694cdf 100644 --- a/tests/python_tests/test_llm_pipeline.py +++ b/tests/python_tests/test_llm_pipeline.py @@ -874,7 +874,8 @@ def test_perf_metrics_with_structured_output(generation_config, prompt): raw_compile_times = np.array(raw_metrics.m_grammar_compile_time) / 1000 assert np.allclose(np.mean(raw_init_times), perf_metrics.get_grammar_compiler_init_time().mean) assert np.allclose(np.mean(raw_compile_times), perf_metrics.get_grammar_compile_time().mean) - # TODO: get max, min values as well + assert np.allclose(np.max(raw_init_times), perf_metrics.get_grammar_compiler_init_time().max) + assert np.allclose(np.min(raw_init_times), perf_metrics.get_grammar_compiler_init_time().min) @pytest.mark.parametrize("pipeline_type", get_main_pipeline_types()) @pytest.mark.parametrize("stop_str", {True, False}) From 1cdc562e337313913bb04b6a8de784f8c5c516ba Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Fri, 4 Jul 2025 12:14:50 +0200 Subject: [PATCH 04/10] simplify tests, store single value for grammar compiler init --- .../include/openvino/genai/perf_metrics.hpp | 14 ++-- .../src/continuous_batching/pipeline_impl.cpp | 4 +- src/cpp/src/perf_metrics.cpp | 11 +-- .../structured_output_controller.cpp | 4 + .../structured_output_controller.hpp | 1 + src/python/py_perf_metrics.cpp | 10 +-- tests/python_tests/test_llm_pipeline.py | 77 ++----------------- 7 files changed, 32 insertions(+), 89 deletions(-) diff --git a/src/cpp/include/openvino/genai/perf_metrics.hpp b/src/cpp/include/openvino/genai/perf_metrics.hpp index 2c1194164c..8970bfc3e3 100644 --- a/src/cpp/include/openvino/genai/perf_metrics.hpp +++ b/src/cpp/include/openvino/genai/perf_metrics.hpp @@ -27,6 +27,7 @@ using MicroSeconds = std::chrono::duration>; * @param m_batch_sizes Batch sizes for each generate call. * @param m_durations Total durations for each generate call in microseconds. * @param m_inference_durations Total inference duration for each generate call in microseconds. + * @param m_grammar_compile_times Time to compile the grammar in microseconds. */ struct OPENVINO_GENAI_EXPORTS RawPerfMetrics { std::vector generate_durations; @@ -40,8 +41,7 @@ struct OPENVINO_GENAI_EXPORTS RawPerfMetrics { std::vector m_durations; std::vector m_inference_durations; - std::vector m_grammar_init_time; - std::vector m_grammar_compile_time; + std::vector m_grammar_compile_times; }; /** @@ -89,6 +89,8 @@ struct OPENVINO_GENAI_EXPORTS SummaryStats { * @param get_generate_duration Returns the mean and standard deviation of generate duration. * @param get_tokenization_duration Returns the mean and standard deviation of tokenization duration. * @param get_detokenization_duration Returns the mean and standard deviation of detokenization duration. + * @param get_grammar_compiler_init_time Returns the time to initialize the grammar compiler in milliseconds. + * @param get_grammar_compile_time Returns the time to compile the grammar in milliseconds. * @param get_microsec Converts a duration to microseconds. * @param m_evaluated Flag indicating if raw metrics were evaluated. * If false, current mean/std TTFT, TPOT, etc. are not actual and evaluate_statistics() should recalculate them. @@ -116,8 +118,8 @@ struct OPENVINO_GENAI_EXPORTS PerfMetrics { MeanStdPair ipot; // Inference time (in ms) per output token. MeanStdPair throughput; // Tokens per second. - SummaryStats grammar_compiler_init_time; // Time to initialize grammar compiler in ms. - SummaryStats grammar_compile_time; // Time to compile grammar in ms. + float grammar_compiler_init_time; // Time to initialize grammar compiler in ms. + SummaryStats grammar_compile_time; // Time to compile grammar in ms. MeanStdPair generate_duration; MeanStdPair inference_duration; @@ -135,8 +137,8 @@ struct OPENVINO_GENAI_EXPORTS PerfMetrics { MeanStdPair get_ipot(); // Inference time (in ms) per output token. MeanStdPair get_throughput(); // Tokens per second. - SummaryStats get_grammar_compiler_init_time(); - SummaryStats get_grammar_compile_time(); + float get_grammar_compiler_init_time(); // in ms + SummaryStats get_grammar_compile_time(); // in ms MeanStdPair get_inference_duration(); // in ms MeanStdPair get_generate_duration(); // in ms diff --git a/src/cpp/src/continuous_batching/pipeline_impl.cpp b/src/cpp/src/continuous_batching/pipeline_impl.cpp index 7c86851dc2..be2f269ac1 100644 --- a/src/cpp/src/continuous_batching/pipeline_impl.cpp +++ b/src/cpp/src/continuous_batching/pipeline_impl.cpp @@ -446,9 +446,9 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vectorget_structured_output_times(); - raw_perf_counters.m_grammar_init_time.emplace_back(times.first); + perf_metrics.grammar_compiler_init_time = times.first; for (const auto& t: times.second) { - raw_perf_counters.m_grammar_compile_time.emplace_back(t); + raw_perf_counters.m_grammar_compile_times.emplace_back(t); } // waiting for competion of streaming diff --git a/src/cpp/src/perf_metrics.cpp b/src/cpp/src/perf_metrics.cpp index 53cae394ba..06cb7cba87 100644 --- a/src/cpp/src/perf_metrics.cpp +++ b/src/cpp/src/perf_metrics.cpp @@ -36,7 +36,7 @@ ov::genai::SummaryStats calc_full_stat(const std::vectorcount() / 1000.0f, minmax.second->count() / 1000.0f, meanstd.mean, meanstd.std}; + return {meanstd.mean, meanstd.std, minmax.first->count() / 1000.0f, minmax.second->count() / 1000.0f}; } float PerfMetrics::get_load_time() { @@ -93,7 +93,7 @@ MeanStdPair PerfMetrics::get_inference_duration() { return inference_duration; } -SummaryStats PerfMetrics::get_grammar_compiler_init_time() { +float PerfMetrics::get_grammar_compiler_init_time() { evaluate_statistics(); return grammar_compiler_init_time; } @@ -140,8 +140,7 @@ void PerfMetrics::evaluate_statistics(std::optional start_time) { ipot = calc_mean_and_std(raw_metrics.m_token_infer_durations); ttft = calc_mean_and_std(raw_metrics.m_times_to_first_token); - grammar_compile_time = calc_full_stat(raw_metrics.m_grammar_compile_time); - grammar_compiler_init_time = calc_full_stat(raw_metrics.m_grammar_init_time); + grammar_compile_time = calc_full_stat(raw_metrics.m_grammar_compile_times); generate_duration = calc_mean_and_std(raw_metrics.generate_durations); tokenization_duration = calc_mean_and_std(raw_metrics.tokenization_durations); @@ -155,7 +154,9 @@ void PerfMetrics::evaluate_statistics(std::optional start_time) { PerfMetrics PerfMetrics::operator+(const PerfMetrics& right) const { OPENVINO_ASSERT(right.load_time == load_time, "generation metrics can be accumulated only for the same pipeline"); - + // TODO: rewrite this map will be used + OPENVINO_ASSERT(right.grammar_compiler_init_time == grammar_compiler_init_time, "grammar compile time can be accumulated only for the same pipeline"); + // Copy left value to res. PerfMetrics res = *this; diff --git a/src/cpp/src/sampling/structured_output/structured_output_controller.cpp b/src/cpp/src/sampling/structured_output/structured_output_controller.cpp index f440e63fbf..b1601b252c 100644 --- a/src/cpp/src/sampling/structured_output/structured_output_controller.cpp +++ b/src/cpp/src/sampling/structured_output/structured_output_controller.cpp @@ -41,6 +41,8 @@ StructuredOutputController::get_logits_transformer(const ov::genai::GenerationCo } std::string backend_name = (*guided_gen_config).backend.value_or(get_default_backend_name()); + std::unique_lock lock(m_mutex); + // Check if backend already instantiated auto impl_it = m_impls.find(backend_name); if (impl_it == m_impls.end()) { @@ -56,6 +58,8 @@ StructuredOutputController::get_logits_transformer(const ov::genai::GenerationCo m_impls[backend_name] = factory_it->second(m_tokenizer, m_vocab_size); impl_it = m_impls.find(backend_name); const auto end = std::chrono::steady_clock::now(); + // TODO: save into map and return as a single value not a list + // TODO: add mutex m_init_grammar_compiler_time = std::chrono::duration_cast(end - start).count(); } diff --git a/src/cpp/src/sampling/structured_output/structured_output_controller.hpp b/src/cpp/src/sampling/structured_output/structured_output_controller.hpp index 1ca1fcb45e..f18b9c9a29 100644 --- a/src/cpp/src/sampling/structured_output/structured_output_controller.hpp +++ b/src/cpp/src/sampling/structured_output/structured_output_controller.hpp @@ -82,6 +82,7 @@ class StructuredOutputController { std::unordered_map> m_impls; const ov::genai::Tokenizer& m_tokenizer; std::optional m_vocab_size; + std::mutex m_mutex; }; } // namespace genai diff --git a/src/python/py_perf_metrics.cpp b/src/python/py_perf_metrics.cpp index b55868a3dd..7d75ad3486 100644 --- a/src/python/py_perf_metrics.cpp +++ b/src/python/py_perf_metrics.cpp @@ -50,6 +50,9 @@ auto raw_perf_metrics_docstring = R"( :param inference_durations : Total inference duration for each generate call in milliseconds. :type batch_sizes: list[float] + + :param grammar_compile_times: Time to compile the grammar in milliseconds. + :type grammar_compile_times: list[float] )"; auto perf_metrics_docstring = R"( @@ -148,11 +151,8 @@ void init_perf_metrics(py::module_& m) { .def_property_readonly("inference_durations", [](const RawPerfMetrics &rw) { return pyutils::get_ms(rw, &RawPerfMetrics::m_inference_durations); }) - .def_property_readonly("m_grammar_init_time", [](const RawPerfMetrics &rw) { - return pyutils::get_ms(rw, &RawPerfMetrics::m_grammar_init_time); - }) - .def_property_readonly("m_grammar_compile_time", [](const RawPerfMetrics &rw) { - return pyutils::get_ms(rw, &RawPerfMetrics::m_grammar_compile_time); + .def_property_readonly("grammar_compile_times", [](const RawPerfMetrics &rw) { + return pyutils::get_ms(rw, &RawPerfMetrics::m_grammar_compile_times); }); py::class_(m, "SummaryStats") diff --git a/tests/python_tests/test_llm_pipeline.py b/tests/python_tests/test_llm_pipeline.py index 2f9e694cdf..e08ff4a941 100644 --- a/tests/python_tests/test_llm_pipeline.py +++ b/tests/python_tests/test_llm_pipeline.py @@ -801,81 +801,16 @@ def test_perf_metrics_with_structured_output(generation_config, prompt): start_time = time.perf_counter() model_id = 'katuni4ka/tiny-random-gemma2' perf_metrics = run_perf_metrics_collection(model_id, generation_config, prompt) - total_time = (time.perf_counter() - start_time) * 1000 - - # Check that load time is adequate. - load_time = perf_metrics.get_load_time() - assert load_time > 0 and load_time < total_time - - # Check that num input and generated tokens are adequate. - num_generated_tokens = perf_metrics.get_num_generated_tokens() - assert num_generated_tokens > 0 and num_generated_tokens <= generation_config['max_new_tokens'] - - num_input_tokens = perf_metrics.get_num_input_tokens() - assert num_input_tokens > 0 and num_input_tokens <= len(prompt) - - mean_ttft, std_ttft = perf_metrics.get_ttft() - assert (mean_ttft, std_ttft) == (perf_metrics.get_ttft().mean, perf_metrics.get_ttft().std) - assert mean_ttft > 0 and mean_ttft < 1000.0 - raw_metrics = perf_metrics.raw_metrics - durations = np.array(raw_metrics.m_durations) / 1000 - # Check that prefill is not included in durations for TPOT calculation. - # For the very long prompt prefill is slow and TTFT is much larger than any other token generation duration. - assert np.all(mean_ttft > durations * 2) - - mean_tpot, std_tpot = perf_metrics.get_tpot() - assert (mean_tpot, std_tpot) == (perf_metrics.get_tpot().mean, perf_metrics.get_tpot().std) - assert mean_tpot > 0 and mean_ttft < 1000.0 - - mean_throughput, std_throughput = perf_metrics.get_throughput() - assert (mean_throughput, std_throughput) == (perf_metrics.get_throughput().mean, perf_metrics.get_throughput().std) - assert mean_throughput > 0 and mean_throughput < 20000.0 - - mean_gen_duration, std_gen_duration = perf_metrics.get_generate_duration() - assert (mean_gen_duration, std_gen_duration) == (perf_metrics.get_generate_duration().mean, perf_metrics.get_generate_duration().std) - assert mean_gen_duration > 0 and load_time + mean_gen_duration < total_time - assert std_gen_duration == 0 - - mean_tok_duration, std_tok_duration = perf_metrics.get_tokenization_duration() - assert (mean_tok_duration, std_tok_duration) == (perf_metrics.get_tokenization_duration().mean, perf_metrics.get_tokenization_duration().std) - assert mean_tok_duration > 0 and mean_tok_duration < mean_gen_duration - assert std_tok_duration == 0 - - mean_detok_duration, std_detok_duration = perf_metrics.get_detokenization_duration() - assert (mean_detok_duration, std_detok_duration) == (perf_metrics.get_detokenization_duration().mean, perf_metrics.get_detokenization_duration().std) - assert mean_detok_duration > 0 and mean_detok_duration < mean_gen_duration - assert std_detok_duration == 0 - - # assert that calculating statistics manually from the raw counters we get the same restults as from PerfMetrics - assert np.allclose(mean_tpot, np.mean(durations)) - assert np.allclose(std_tpot, np.std(durations)) - - raw_dur = np.array(raw_metrics.generate_durations) / 1000 - assert np.allclose(mean_gen_duration, np.mean(raw_dur)) - assert np.allclose(std_gen_duration, np.std(raw_dur)) - - raw_dur = np.array(raw_metrics.tokenization_durations) / 1000 - assert np.allclose(mean_tok_duration, np.mean(raw_dur)) - assert np.allclose(std_tok_duration, np.std(raw_dur)) - - raw_dur = np.array(raw_metrics.detokenization_durations) / 1000 - assert np.allclose(mean_detok_duration, np.mean(raw_dur)) - assert np.allclose(std_detok_duration, np.std(raw_dur)) - - assert len(raw_metrics.m_times_to_first_token) > 0 - assert len(raw_metrics.m_batch_sizes) > 0 - assert len(raw_metrics.m_durations) > 0 - assert len(raw_metrics.m_grammar_init_time) == 1 - assert len(raw_metrics.m_grammar_compile_time) > 0 + assert perf_metrics.get_grammar_compiler_init_time() > 0.0 + assert len(raw_metrics.grammar_compile_times) > 0 - raw_init_times = np.array(raw_metrics.m_grammar_init_time) / 1000 - raw_compile_times = np.array(raw_metrics.m_grammar_compile_time) / 1000 - assert np.allclose(np.mean(raw_init_times), perf_metrics.get_grammar_compiler_init_time().mean) + raw_compile_times = np.array(raw_metrics.grammar_compile_times) / 1000 assert np.allclose(np.mean(raw_compile_times), perf_metrics.get_grammar_compile_time().mean) - assert np.allclose(np.max(raw_init_times), perf_metrics.get_grammar_compiler_init_time().max) - assert np.allclose(np.min(raw_init_times), perf_metrics.get_grammar_compiler_init_time().min) + assert np.allclose(np.std(raw_compile_times), perf_metrics.get_grammar_compile_time().std) + assert np.allclose(np.min(raw_compile_times), perf_metrics.get_grammar_compile_time().min) + assert np.allclose(np.max(raw_compile_times), perf_metrics.get_grammar_compile_time().max) @pytest.mark.parametrize("pipeline_type", get_main_pipeline_types()) @pytest.mark.parametrize("stop_str", {True, False}) From df262b4114215f7beec9304e31e330b4a4d84a27 Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Fri, 4 Jul 2025 12:44:23 +0200 Subject: [PATCH 05/10] use map grammar_compiler_init_times --- .../include/openvino/genai/perf_metrics.hpp | 9 +++-- .../src/continuous_batching/pipeline_impl.cpp | 2 +- src/cpp/src/perf_metrics.cpp | 24 +++++++++--- src/cpp/src/sampling/sampler.cpp | 8 ++-- src/cpp/src/sampling/sampler.hpp | 3 +- .../structured_output_controller.cpp | 13 +++++-- .../structured_output_controller.hpp | 6 ++- .../openvino_genai/py_openvino_genai.pyi | 2 +- src/python/py_perf_metrics.cpp | 8 +++- tests/python_tests/test_llm_pipeline.py | 39 +++++++++++-------- 10 files changed, 77 insertions(+), 37 deletions(-) diff --git a/src/cpp/include/openvino/genai/perf_metrics.hpp b/src/cpp/include/openvino/genai/perf_metrics.hpp index 8970bfc3e3..00cf31c359 100644 --- a/src/cpp/include/openvino/genai/perf_metrics.hpp +++ b/src/cpp/include/openvino/genai/perf_metrics.hpp @@ -7,6 +7,8 @@ #include "openvino/genai/visibility.hpp" #include #include +#include +#include #include namespace ov { @@ -89,7 +91,7 @@ struct OPENVINO_GENAI_EXPORTS SummaryStats { * @param get_generate_duration Returns the mean and standard deviation of generate duration. * @param get_tokenization_duration Returns the mean and standard deviation of tokenization duration. * @param get_detokenization_duration Returns the mean and standard deviation of detokenization duration. - * @param get_grammar_compiler_init_time Returns the time to initialize the grammar compiler in milliseconds. + * @param get_grammar_compiler_init_times Returns a map with the time to initialize the grammar compiler for each backend in milliseconds. * @param get_grammar_compile_time Returns the time to compile the grammar in milliseconds. * @param get_microsec Converts a duration to microseconds. * @param m_evaluated Flag indicating if raw metrics were evaluated. @@ -118,7 +120,8 @@ struct OPENVINO_GENAI_EXPORTS PerfMetrics { MeanStdPair ipot; // Inference time (in ms) per output token. MeanStdPair throughput; // Tokens per second. - float grammar_compiler_init_time; // Time to initialize grammar compiler in ms. + // Time to initialize grammar compiler for each backend in ms. + std::map grammar_compiler_init_times; SummaryStats grammar_compile_time; // Time to compile grammar in ms. MeanStdPair generate_duration; @@ -137,7 +140,7 @@ struct OPENVINO_GENAI_EXPORTS PerfMetrics { MeanStdPair get_ipot(); // Inference time (in ms) per output token. MeanStdPair get_throughput(); // Tokens per second. - float get_grammar_compiler_init_time(); // in ms + std::map get_grammar_compiler_init_times(); SummaryStats get_grammar_compile_time(); // in ms MeanStdPair get_inference_duration(); // in ms diff --git a/src/cpp/src/continuous_batching/pipeline_impl.cpp b/src/cpp/src/continuous_batching/pipeline_impl.cpp index be2f269ac1..575ee14339 100644 --- a/src/cpp/src/continuous_batching/pipeline_impl.cpp +++ b/src/cpp/src/continuous_batching/pipeline_impl.cpp @@ -446,7 +446,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vectorget_structured_output_times(); - perf_metrics.grammar_compiler_init_time = times.first; + perf_metrics.grammar_compiler_init_times = times.first; for (const auto& t: times.second) { raw_perf_counters.m_grammar_compile_times.emplace_back(t); } diff --git a/src/cpp/src/perf_metrics.cpp b/src/cpp/src/perf_metrics.cpp index 06cb7cba87..f14884f11e 100644 --- a/src/cpp/src/perf_metrics.cpp +++ b/src/cpp/src/perf_metrics.cpp @@ -93,9 +93,8 @@ MeanStdPair PerfMetrics::get_inference_duration() { return inference_duration; } -float PerfMetrics::get_grammar_compiler_init_time() { - evaluate_statistics(); - return grammar_compiler_init_time; +std::map PerfMetrics::get_grammar_compiler_init_times() { + return grammar_compiler_init_times; } SummaryStats PerfMetrics::get_grammar_compile_time() { @@ -154,11 +153,22 @@ void PerfMetrics::evaluate_statistics(std::optional start_time) { PerfMetrics PerfMetrics::operator+(const PerfMetrics& right) const { OPENVINO_ASSERT(right.load_time == load_time, "generation metrics can be accumulated only for the same pipeline"); - // TODO: rewrite this map will be used - OPENVINO_ASSERT(right.grammar_compiler_init_time == grammar_compiler_init_time, "grammar compile time can be accumulated only for the same pipeline"); // Copy left value to res. PerfMetrics res = *this; + + // maps with grammar compiler init times should not have conflicting keys + // {{"xgrammar", 10}} + {{"llmguidance", 20}} = {{"grammar", 10}, {"llmguidance", 20}} - is OK! + // {{"xgrammar", 10}} + {{"xgrammar", 10}} = {{"xgrammar", 10}} - is OK! + // {{"xgrammar", 10}} + {{"xgrammar", 20}} = is NOT OK! Fails on assert! + for (const auto& [key, value] : right.grammar_compiler_init_times) { + auto it = res.grammar_compiler_init_times.find(key); + if (it != res.grammar_compiler_init_times.end()) { + OPENVINO_ASSERT(it->second == value, "Grammar compiler init time for the same backend should be the same. ", + "You are trying to accumulate metrics for different pipelines which is not allowed."); + } + res.grammar_compiler_init_times[key] = value; + } // Concatenate durations, batch_sizes first token times. auto& new_durations = res.raw_metrics.m_durations; @@ -190,6 +200,10 @@ PerfMetrics PerfMetrics::operator+(const PerfMetrics& right) const { new_detok_durations.insert(new_detok_durations.end(), right_detok_durations.begin(), right_detok_durations.end()); new_gen_durations.insert(new_gen_durations.end(), right_gen_durations.begin(), right_gen_durations.end()); + // Concatenate structured output compilation times. + auto& new_grammar_compile_times = res.raw_metrics.m_grammar_compile_times; + new_grammar_compile_times.insert(new_grammar_compile_times.end(), right.raw_metrics.m_grammar_compile_times.begin(), right.raw_metrics.m_grammar_compile_times.end()); + res.num_generated_tokens += right.num_generated_tokens; res.num_input_tokens += right.num_input_tokens; res.m_evaluated = false; diff --git a/src/cpp/src/sampling/sampler.cpp b/src/cpp/src/sampling/sampler.cpp index e54681b239..7a593dc6e1 100644 --- a/src/cpp/src/sampling/sampler.cpp +++ b/src/cpp/src/sampling/sampler.cpp @@ -241,18 +241,18 @@ std::map Sampler::GroupBeamSearcher::get_beam_idxs() { return next_beams; } -std::pair> Sampler::get_structured_output_times() { +std::pair, std::vector> Sampler::get_structured_output_times() { if (m_structured_output_controller) { - return {m_structured_output_controller->m_init_grammar_compiler_time, m_structured_output_controller->m_grammar_compile_times}; + return m_structured_output_controller->get_times(); } else { // If compiled without structured output support, return empty times - return {0.0f, {}}; + return {{}, {}}; } } void Sampler::clear_structured_output_compile_times() { if (m_structured_output_controller) { - m_structured_output_controller->m_grammar_compile_times.clear(); + m_structured_output_controller->clear_compile_times(); } } diff --git a/src/cpp/src/sampling/sampler.hpp b/src/cpp/src/sampling/sampler.hpp index 1e64040b1d..4e6129fba2 100644 --- a/src/cpp/src/sampling/sampler.hpp +++ b/src/cpp/src/sampling/sampler.hpp @@ -123,7 +123,8 @@ class Sampler { void create_logit_processor(uint64_t request_id, const GenerationConfig& sampling_parameters, const TokenIds& prompt); std::map get_beam_idxs(SequenceGroup::CPtr sequence_group); - std::pair> get_structured_output_times(); + // pair with map with backend name and corresponding acompiler init time, and vector of compile times for each concrete grammar + std::pair, std::vector> get_structured_output_times(); void clear_structured_output_compile_times(); }; diff --git a/src/cpp/src/sampling/structured_output/structured_output_controller.cpp b/src/cpp/src/sampling/structured_output/structured_output_controller.cpp index b1601b252c..cf569d28e4 100644 --- a/src/cpp/src/sampling/structured_output/structured_output_controller.cpp +++ b/src/cpp/src/sampling/structured_output/structured_output_controller.cpp @@ -58,9 +58,7 @@ StructuredOutputController::get_logits_transformer(const ov::genai::GenerationCo m_impls[backend_name] = factory_it->second(m_tokenizer, m_vocab_size); impl_it = m_impls.find(backend_name); const auto end = std::chrono::steady_clock::now(); - // TODO: save into map and return as a single value not a list - // TODO: add mutex - m_init_grammar_compiler_time = std::chrono::duration_cast(end - start).count(); + m_init_grammar_compiler_times[backend_name] = std::chrono::duration_cast(end - start).count(); } // Use the instantiated backend @@ -71,5 +69,14 @@ StructuredOutputController::get_logits_transformer(const ov::genai::GenerationCo return res; } +std::pair, std::vector> StructuredOutputController::get_times() const { + return {m_init_grammar_compiler_times, m_grammar_compile_times}; +} + +void StructuredOutputController::clear_compile_times() { + std::lock_guard lock(m_mutex); + m_grammar_compile_times.clear(); +} + } // namespace genai } // namespace ov diff --git a/src/cpp/src/sampling/structured_output/structured_output_controller.hpp b/src/cpp/src/sampling/structured_output/structured_output_controller.hpp index f18b9c9a29..c40e1b99e3 100644 --- a/src/cpp/src/sampling/structured_output/structured_output_controller.hpp +++ b/src/cpp/src/sampling/structured_output/structured_output_controller.hpp @@ -76,9 +76,11 @@ class StructuredOutputController { static std::string& get_default_backend_name(); static std::unordered_map& get_backend_registry(); - float m_init_grammar_compiler_time; - std::vector m_grammar_compile_times; + std::pair, std::vector> get_times() const; + void clear_compile_times(); private: + std::map m_init_grammar_compiler_times; + std::vector m_grammar_compile_times; std::unordered_map> m_impls; const ov::genai::Tokenizer& m_tokenizer; std::optional m_vocab_size; diff --git a/src/python/openvino_genai/py_openvino_genai.pyi b/src/python/openvino_genai/py_openvino_genai.pyi index f4a41b4045..9200ea9347 100644 --- a/src/python/openvino_genai/py_openvino_genai.pyi +++ b/src/python/openvino_genai/py_openvino_genai.pyi @@ -1345,7 +1345,7 @@ class PerfMetrics: ... def get_grammar_compile_time(self) -> SummaryStats: ... - def get_grammar_compiler_init_time(self) -> SummaryStats: + def get_grammar_compiler_init_times(self) -> SummaryStats: ... def get_inference_duration(self) -> MeanStdPair: ... diff --git a/src/python/py_perf_metrics.cpp b/src/python/py_perf_metrics.cpp index 7d75ad3486..00ec520c31 100644 --- a/src/python/py_perf_metrics.cpp +++ b/src/python/py_perf_metrics.cpp @@ -101,6 +101,12 @@ auto perf_metrics_docstring = R"( :param get_detokenization_duration: Returns the mean and standard deviation of detokenization durations in milliseconds. :type get_detokenization_duration: MeanStdPair + :param get_grammar_compiler_init_times: Returns a map with the time to initialize the grammar compiler for each backend in milliseconds. + :type get_grammar_compiler_init_times: dict[str, float] + + :param get_grammar_compile_time: Returns the mean, standard deviation, min, and max of grammar compile times in milliseconds. + :type get_grammar_compile_time: SummaryStats + :param raw_metrics: A structure of RawPerfMetrics type that holds raw metrics. :type raw_metrics: RawPerfMetrics )"; @@ -176,7 +182,7 @@ void init_perf_metrics(py::module_& m) { py::class_(m, "PerfMetrics", perf_metrics_docstring) .def(py::init<>()) .def("get_load_time", &PerfMetrics::get_load_time) - .def("get_grammar_compiler_init_time", &PerfMetrics::get_grammar_compiler_init_time) + .def("get_grammar_compiler_init_times", &PerfMetrics::get_grammar_compiler_init_times) .def("get_grammar_compile_time", &PerfMetrics::get_grammar_compile_time) .def("get_num_generated_tokens", &PerfMetrics::get_num_generated_tokens) .def("get_num_input_tokens", &PerfMetrics::get_num_input_tokens) diff --git a/tests/python_tests/test_llm_pipeline.py b/tests/python_tests/test_llm_pipeline.py index e08ff4a941..1ac4382fc2 100644 --- a/tests/python_tests/test_llm_pipeline.py +++ b/tests/python_tests/test_llm_pipeline.py @@ -10,6 +10,8 @@ import json import numpy as np from pathlib import Path +from typing import Literal +from pydantic import BaseModel, Field import openvino as ov import openvino_genai as ov_genai @@ -778,32 +780,30 @@ def test_perf_metrics(generation_config, prompt): assert len(raw_metrics.m_batch_sizes) > 0 assert len(raw_metrics.m_durations) > 0 -import json -from typing import Literal -from pydantic import BaseModel, Field - - -class Person(BaseModel): - name: str = Field(pattern=r"^[A-Z][a-z]{1,20}$") - surname: str = Field(pattern=r"^[A-Z][a-z]{1,20}$") - age: int - city: Literal["Dublin", "Dubai", "Munich"] test_cases = [ - (dict(max_new_tokens=20, structured_output_config=ov_genai.StructuredOutputConfig(json_schema=json.dumps(Person.model_json_schema()))), - 'Generate json of a person'), + (dict(max_new_tokens=20), 'Generate json of a person'), ] @pytest.mark.parametrize("generation_config,prompt", test_cases) @pytest.mark.precommit @pytest.mark.nightly def test_perf_metrics_with_structured_output(generation_config, prompt): - import time - start_time = time.perf_counter() + class Person(BaseModel): + name: str = Field(pattern=r"^[A-Z][a-z]{1,20}$") + surname: str = Field(pattern=r"^[A-Z][a-z]{1,20}$") + age: int + city: Literal["Dublin", "Dubai", "Munich"] + generation_config.update(dict(structured_output_config=ov_genai.StructuredOutputConfig(json_schema=json.dumps(Person.model_json_schema())))) + model_id = 'katuni4ka/tiny-random-gemma2' - perf_metrics = run_perf_metrics_collection(model_id, generation_config, prompt) + _, _, models_path = download_and_convert_model(model_id) + ov_pipe = create_ov_pipeline(models_path) + perf_metrics = ov_pipe.generate([prompt], **generation_config).perf_metrics raw_metrics = perf_metrics.raw_metrics - assert perf_metrics.get_grammar_compiler_init_time() > 0.0 + assert len(perf_metrics.get_grammar_compiler_init_times()) > 0 + assert 'xgrammar' in perf_metrics.get_grammar_compiler_init_times() and perf_metrics.get_grammar_compiler_init_times()['xgrammar'] > 0.0 + assert len(raw_metrics.grammar_compile_times) > 0 raw_compile_times = np.array(raw_metrics.grammar_compile_times) / 1000 @@ -812,6 +812,13 @@ def test_perf_metrics_with_structured_output(generation_config, prompt): assert np.allclose(np.min(raw_compile_times), perf_metrics.get_grammar_compile_time().min) assert np.allclose(np.max(raw_compile_times), perf_metrics.get_grammar_compile_time().max) + # Check that metrics are correctly accumulated/concatenated + perf_metrics_2 = ov_pipe.generate([prompt], **generation_config).perf_metrics + raw_metrics_2 = perf_metrics_2.raw_metrics + accumulated_metrics = perf_metrics + perf_metrics_2 + assert accumulated_metrics.raw_metrics.grammar_compile_times == raw_metrics.grammar_compile_times + raw_metrics_2.grammar_compile_times + + @pytest.mark.parametrize("pipeline_type", get_main_pipeline_types()) @pytest.mark.parametrize("stop_str", {True, False}) @pytest.mark.precommit From 5dcdf2bb3a7a84bdf847bc221408468310518857 Mon Sep 17 00:00:00 2001 From: Alina Kladieva Date: Fri, 4 Jul 2025 14:34:37 +0200 Subject: [PATCH 06/10] Add issues: write permissions for labeler --- .github/workflows/labeler.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 102dee120b..2f25c4ebbe 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -9,6 +9,7 @@ jobs: permissions: contents: read pull-requests: write + issues: write runs-on: ubuntu-latest steps: - uses: akladiev/labeler@eeac5941e7fb6f980d47e038ac0665168851c874 # v4.3.1 From 894ea6e58804993a2a28bc5644dd449864000552 Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Fri, 4 Jul 2025 15:40:05 +0200 Subject: [PATCH 07/10] Update src/cpp/src/sampling/sampler.hpp Co-authored-by: Artur Paniukov --- src/cpp/src/sampling/sampler.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cpp/src/sampling/sampler.hpp b/src/cpp/src/sampling/sampler.hpp index 4e6129fba2..52faa79702 100644 --- a/src/cpp/src/sampling/sampler.hpp +++ b/src/cpp/src/sampling/sampler.hpp @@ -123,7 +123,7 @@ class Sampler { void create_logit_processor(uint64_t request_id, const GenerationConfig& sampling_parameters, const TokenIds& prompt); std::map get_beam_idxs(SequenceGroup::CPtr sequence_group); - // pair with map with backend name and corresponding acompiler init time, and vector of compile times for each concrete grammar + // pair with map with backend name and corresponding compiler init time, and vector of compile times for each concrete grammar std::pair, std::vector> get_structured_output_times(); void clear_structured_output_compile_times(); }; From ac9604f99ffda9869473f2a89860448ced603206 Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Fri, 4 Jul 2025 16:19:06 +0200 Subject: [PATCH 08/10] update pybind-stubgen --- src/cpp/src/generation_config.cpp | 2 +- .../openvino_genai/py_openvino_genai.pyi | 20 ++++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/cpp/src/generation_config.cpp b/src/cpp/src/generation_config.cpp index b8c72b8790..d3cb461830 100644 --- a/src/cpp/src/generation_config.cpp +++ b/src/cpp/src/generation_config.cpp @@ -150,7 +150,7 @@ void GenerationConfig::update_generation_config(const ov::AnyMap& properties) { read_anymap_param(properties, "num_assistant_tokens", num_assistant_tokens); read_anymap_param(properties, "max_ngram_size", max_ngram_size); - // Structured Output generation + // Structured output read_anymap_param(properties, "structured_output_config", structured_output_config); } diff --git a/src/python/openvino_genai/py_openvino_genai.pyi b/src/python/openvino_genai/py_openvino_genai.pyi index 9200ea9347..b971780680 100644 --- a/src/python/openvino_genai/py_openvino_genai.pyi +++ b/src/python/openvino_genai/py_openvino_genai.pyi @@ -1330,6 +1330,12 @@ class PerfMetrics: :param get_detokenization_duration: Returns the mean and standard deviation of detokenization durations in milliseconds. :type get_detokenization_duration: MeanStdPair + :param get_grammar_compiler_init_times: Returns a map with the time to initialize the grammar compiler for each backend in milliseconds. + :type get_grammar_compiler_init_times: dict[str, float] + + :param get_grammar_compile_time: Returns the mean, standard deviation, min, and max of grammar compile times in milliseconds. + :type get_grammar_compile_time: SummaryStats + :param raw_metrics: A structure of RawPerfMetrics type that holds raw metrics. :type raw_metrics: RawPerfMetrics """ @@ -1345,7 +1351,7 @@ class PerfMetrics: ... def get_grammar_compile_time(self) -> SummaryStats: ... - def get_grammar_compiler_init_times(self) -> SummaryStats: + def get_grammar_compiler_init_times(self) -> dict[str, float]: ... def get_inference_duration(self) -> MeanStdPair: ... @@ -1463,6 +1469,9 @@ class RawPerfMetrics: :param inference_durations : Total inference duration for each generate call in milliseconds. :type batch_sizes: list[float] + + :param grammar_compile_times: Time to compile the grammar in milliseconds. + :type grammar_compile_times: list[float] """ def __init__(self) -> None: ... @@ -1473,6 +1482,9 @@ class RawPerfMetrics: def generate_durations(self) -> list[float]: ... @property + def grammar_compile_times(self) -> list[float]: + ... + @property def inference_durations(self) -> list[float]: ... @property @@ -1482,12 +1494,6 @@ class RawPerfMetrics: def m_durations(self) -> list[float]: ... @property - def m_grammar_compile_time(self) -> list[float]: - ... - @property - def m_grammar_init_time(self) -> list[float]: - ... - @property def m_new_token_times(self) -> list[float]: ... @property From 6ba73e18988dc7e265cbef8909e8b58302b20530 Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Fri, 4 Jul 2025 20:22:19 +0200 Subject: [PATCH 09/10] skip structured output tests when build with -DENABLE_XGRAMMAR=OFF --- .github/workflows/linux.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index df1c641835..3ad1ec79d6 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -902,8 +902,7 @@ jobs: PYTHONPATH: "${{ env.BUILD_DIR }}:" run: | source ${{ env.OV_INSTALL_DIR }}/setupvars.sh - python3 -m pytest -v ${{ env.SRC_DIR }}/tests/python_tests/test_llm_pipeline.py - # python3 -m pytest -v ${{ env.SRC_DIR }}/tests/python_tests/test_structured_output.py + python3 -m pytest -v ${{ env.SRC_DIR }}/tests/python_tests/test_llm_pipeline.py -k "not test_perf_metrics_with_structured_output" Overall_Status: name: ci/gha_overall_status_linux From e2926bd1707b04d2cb9f2422e417d6df5eaf88c3 Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Fri, 11 Jul 2025 11:00:22 +0200 Subject: [PATCH 10/10] resolve conflicts --- src/python/openvino_genai/py_openvino_genai.pyi | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/python/openvino_genai/py_openvino_genai.pyi b/src/python/openvino_genai/py_openvino_genai.pyi index a48fce40a0..31c67e7786 100644 --- a/src/python/openvino_genai/py_openvino_genai.pyi +++ b/src/python/openvino_genai/py_openvino_genai.pyi @@ -546,6 +546,12 @@ class ExtendedPerfMetrics: :param get_detokenization_duration: Returns the mean and standard deviation of detokenization durations in milliseconds. :type get_detokenization_duration: MeanStdPair + :param get_grammar_compiler_init_times: Returns a map with the time to initialize the grammar compiler for each backend in milliseconds. + :type get_grammar_compiler_init_times: dict[str, float] + + :param get_grammar_compile_time: Returns the mean, standard deviation, min, and max of grammar compile times in milliseconds. + :type get_grammar_compile_time: SummaryStats + :param raw_metrics: A structure of RawPerfMetrics type that holds raw metrics. :type raw_metrics: RawPerfMetrics """