diff --git a/tools/llm_bench/benchmark.py b/tools/llm_bench/benchmark.py
index 1f2abfbc7f..8a4e5f8550 100644
--- a/tools/llm_bench/benchmark.py
+++ b/tools/llm_bench/benchmark.py
@@ -148,7 +148,12 @@ def get_argprser():
     parser.add_argument('--lora_alphas', nargs='*', help='Alphas params for LoRA adapters.', required=False, default=[])
     parser.add_argument("--lora_mode", choices=["auto", "fuse", "static", "static_rank", "dynamic"], help="LoRA adapters loading mode")
     parser.add_argument("--empty_lora", action="store_true", help="Inference without lora")
-    parser.add_argument("--use_cb", action="store_true", help="Use Continuous Batching inference mode")
+    parser.add_argument(
+        "--use_cb",
+        action="store_true",
+        help='Deprecated, will be removed soon! Continuous batching mode is used by default. '
+             'To switch to SDPA mode, please set {"ATTENTION_BACKEND": "SDPA"} in --load_config.'
+    )
     parser.add_argument("--cb_config", required=False, default=None, help="Path to file with Continuous Batching Scheduler settings or dict")
     parser.add_argument("--draft_model", required=False, default=None, help="Path to draft model folder including IR files for Speculative decoding generation")
diff --git a/tools/llm_bench/llm_bench_utils/config_class.py b/tools/llm_bench/llm_bench_utils/config_class.py
index 64c7256c96..d2bec4e6c7 100644
--- a/tools/llm_bench/llm_bench_utils/config_class.py
+++ b/tools/llm_bench/llm_bench_utils/config_class.py
@@ -134,3 +134,6 @@
     "text2img": "text-to-image",
     "inpainting": "inpainting"
 }
+
+PA_ATTENTION_BACKEND = "PA"
+SDPA_ATTENTION_BACKEND = "SDPA"
diff --git a/tools/llm_bench/llm_bench_utils/model_utils.py b/tools/llm_bench/llm_bench_utils/model_utils.py
index efc18ff20f..873917dbc6 100644
--- a/tools/llm_bench/llm_bench_utils/model_utils.py
+++ b/tools/llm_bench/llm_bench_utils/model_utils.py
@@ -5,10 +5,15 @@
 import json
 import logging as log
 from pathlib import Path
-from llm_bench_utils.config_class import DEFAULT_MODEL_CLASSES, USE_CASES, OV_MODEL_CLASSES_MAPPING, PT_MODEL_CLASSES_MAPPING
+from llm_bench_utils.config_class import (
+    DEFAULT_MODEL_CLASSES,
+    USE_CASES,
+    OV_MODEL_CLASSES_MAPPING,
+    PT_MODEL_CLASSES_MAPPING,
+    PA_ATTENTION_BACKEND
+)
 import librosa
-

 KNOWN_PRECISIONS = [
     'FP32', 'FP16',
     'FP16-INT8', 'INT8', 'INT8_compressed_weights', 'INT8_quantized', 'PT_compressed_weights',
@@ -149,12 +154,6 @@ def analyze_args(args):
     model_args['lora_alphas'] = args.lora_alphas
     model_args['lora_mode'] = args.lora_mode
     model_args['empty_lora'] = args.empty_lora
-    use_cb = args.use_cb or args.draft_model
-    if args.device == "NPU" and use_cb:
-        log.warning("Continious batching and Speculative Decoding are not supported for NPU device")
-        use_cb = False
-        args.draft_model = None
-    model_args["use_cb"] = use_cb
     model_args['devices'] = args.device
     model_args['prompt_index'] = [] if args.prompt_index is not None else None
     if model_args['prompt_index'] is not None:
@@ -181,18 +180,21 @@ def analyze_args(args):
     model_args['config'] = config
     if model_framework == 'ov':
         set_default_param_for_ov_config(model_args['config'])
+        if 'ATTENTION_BACKEND' not in model_args['config'] and use_case in ['text_gen', 'vlm'] and args.device != "NPU" and not optimum:
+            model_args['config']['ATTENTION_BACKEND'] = PA_ATTENTION_BACKEND
         log.info(f"OV Config={model_args['config']}")
     elif model_framework == 'pt':
         log.info(f"PT Config={model_args['config']}")
     model_args['model_type'] = get_model_type(model_name, use_case, model_framework)
     model_args['model_name'] = model_name
-    if use_cb and optimum:
-        raise RuntimeError("Continuous batching mode supported only via OpenVINO GenAI")
     cb_config = None
     if args.cb_config:
         cb_config = get_config(args.cb_config)
     model_args["cb_config"] = cb_config
+    if args.draft_model and (args.device == "NPU" or model_args['config']['ATTENTION_BACKEND'] != PA_ATTENTION_BACKEND):
+        log.warning("Speculative Decoding is supported only with the Paged Attention backend and is not supported on NPU devices")
+        args.draft_model = None
     model_args['draft_model'] = args.draft_model
     model_args['draft_device'] = args.draft_device
     draft_cb_config = None
diff --git a/tools/llm_bench/llm_bench_utils/ov_utils.py b/tools/llm_bench/llm_bench_utils/ov_utils.py
index 91ea276dc6..1c0867023f 100644
--- a/tools/llm_bench/llm_bench_utils/ov_utils.py
+++ b/tools/llm_bench/llm_bench_utils/ov_utils.py
@@ -18,7 +18,8 @@
     DEFAULT_MODEL_CLASSES,
     IMAGE_GEN_CLS,
     INPAINTING_IMAGE_GEN_CLS,
-    IMAGE_TO_IMAGE_GEN_CLS
+    IMAGE_TO_IMAGE_GEN_CLS,
+    PA_ATTENTION_BACKEND
 )
 from transformers import pipeline
 import queue
@@ -191,7 +192,7 @@ def create_genai_text_gen_model(model_path, device, ov_config, memory_monitor, *
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

     draft_model_path = kwargs.get("draft_model", '')
-    cb = kwargs.get("use_cb", False)
+    cb = ov_config.get('ATTENTION_BACKEND', '') == PA_ATTENTION_BACKEND
     cb_config = kwargs.get("cb_config")
     use_streamer_metrics = False
     if cb or cb_config is not None or draft_model_path:
@@ -599,7 +600,7 @@ def create_genai_image_text_gen_model(model_path, device, ov_config, memory_moni
     processor_config = get_vlm_processor(model_path)

-    cb = kwargs.get("use_cb", False)
+    cb = ov_config.get('ATTENTION_BACKEND', '') == PA_ATTENTION_BACKEND
     cb_config = kwargs.get("cb_config")
     if cb or cb_config is not None:
         log.info("Continuous Batching mode activated")
diff --git a/tools/llm_bench/task/visual_language_generation.py b/tools/llm_bench/task/visual_language_generation.py
index 3b782e0229..df1e56c369 100644
--- a/tools/llm_bench/task/visual_language_generation.py
+++ b/tools/llm_bench/task/visual_language_generation.py
@@ -168,14 +168,6 @@ def run_visual_language_generation_optimum(
             log.warning(f"[{num}] Prompt[{prompt_index}]'s md5 {result_md5_list} "
                         f"is different from md5 of the {num - 1} iteration {prev_md5}")
             metrics_print.print_generated(num, warm_up=(num == 0), generated=generated_text[0], prompt_idx=prompt_index)
-            if not args.get("use_cb", False):
-                if num == 1:
-                    # if the device is CPU, throw exception
-                    if args['devices'].lower().startswith('cpu') is True:
-                        assert (result_md5_list == prev_md5)
-                else:
-                    # throw exception
-                    assert (result_md5_list == prev_md5)
     else:
         metrics_print.print_generated(num, warm_up=(num == 0), generated=generated_text[0], prompt_idx=prompt_index)
     if bench_hook is not None:
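Not part of the patch: a minimal sketch, for reviewers, of the backend-selection behaviour this change introduces in `analyze_args`. The helper name `resolve_attention_backend` is illustrative only; the real logic lives inline in `model_utils.py` above, and the default is assumed to apply only to GenAI (non-optimum) text_gen/vlm runs on non-NPU devices, per the added condition.

```python
# Illustrative sketch only (not part of the patch): mirrors the defaulting logic
# added to analyze_args(). Paged Attention ("PA", i.e. continuous batching) becomes
# the default backend for text_gen/vlm GenAI runs off NPU; a user-supplied
# {"ATTENTION_BACKEND": "SDPA"} via --load_config overrides it.
PA_ATTENTION_BACKEND = "PA"
SDPA_ATTENTION_BACKEND = "SDPA"


def resolve_attention_backend(ov_config: dict, use_case: str, device: str, optimum: bool) -> dict:
    """Return ov_config with ATTENTION_BACKEND defaulted the same way the patch does."""
    if 'ATTENTION_BACKEND' not in ov_config and use_case in ['text_gen', 'vlm'] and device != "NPU" and not optimum:
        ov_config['ATTENTION_BACKEND'] = PA_ATTENTION_BACKEND
    return ov_config


# Default: continuous batching (Paged Attention) is selected.
assert resolve_attention_backend({}, 'text_gen', 'CPU', optimum=False)['ATTENTION_BACKEND'] == PA_ATTENTION_BACKEND
# An explicit override from --load_config is respected.
assert resolve_attention_backend({'ATTENTION_BACKEND': SDPA_ATTENTION_BACKEND}, 'text_gen', 'CPU', optimum=False)['ATTENTION_BACKEND'] == SDPA_ATTENTION_BACKEND
# NPU and optimum runs are left untouched (no ATTENTION_BACKEND key is added).
assert 'ATTENTION_BACKEND' not in resolve_attention_backend({}, 'text_gen', 'NPU', optimum=False)
```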