diff --git a/tools/who_what_benchmark/whowhatbench/model_loaders.py b/tools/who_what_benchmark/whowhatbench/model_loaders.py
index 50a28e011f..90f667c28b 100644
--- a/tools/who_what_benchmark/whowhatbench/model_loaders.py
+++ b/tools/who_what_benchmark/whowhatbench/model_loaders.py
@@ -234,9 +234,35 @@ def load_visual_text_model(
                 model_id, trust_remote_code=True, device_map=device.lower()
             )
         except ValueError:
+            from_pretrained_kwargs = {}
+            if config.model_type == "phi4mm":
+                if "activation_checkpointing" in config.audio_processor["config"]:
+                    config.audio_processor["config"]["activation_checkpointing"] = ""
+                config._attn_implementation = "sdpa"
+                from_pretrained_kwargs["config"] = config
             model = AutoModelForCausalLM.from_pretrained(
-                model_id, trust_remote_code=True, device_map=device.lower(), _attn_implementation="eager", use_flash_attention_2=False
+                model_id,
+                trust_remote_code=True,
+                device_map=device.lower(),
+                use_flash_attention_2=False,
+                **from_pretrained_kwargs,
             )
+
+            if config.model_type == "phi4mm":
+                use_lora = False
+                if hasattr(config, "vision_lora") and config.vision_lora is not None:
+                    model.set_lora_adapter("vision")
+                    use_lora = True
+                if hasattr(config, "speech_lora") and config.speech_lora is not None:
+                    model.set_lora_adapter("speech")
+                    use_lora = True
+                if use_lora:
+                    model.unset_lora_adapter = lambda: None
+                    model.set_lora_adapter = lambda _: None
+                if hasattr(model.model, "_require_grads_hook"):
+                    model.model.disable_input_require_grads()
+            model.eval()
         try:
             model.get_vision_tower().load_model()
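Note on the LoRA handling above: the hunk keeps phi4mm's vision and speech adapters permanently active by enabling them once and then shadowing the model's own toggling methods with no-op lambdas, so downstream generation code cannot switch them off. A minimal standalone sketch of that monkey-patching pattern is below; the `PhiStub` class is an illustrative stand-in for the remote-code model class, and only the `set_lora_adapter`/`unset_lora_adapter` names come from the diff.

```python
class PhiStub:
    """Illustrative stand-in for the remote-code phi4mm model class."""

    def __init__(self):
        self.active_adapters = set()

    def set_lora_adapter(self, name):
        self.active_adapters.add(name)

    def unset_lora_adapter(self):
        self.active_adapters.clear()


model = PhiStub()
model.set_lora_adapter("vision")
model.set_lora_adapter("speech")

# Same trick as in the patch: once both adapters are active, shadow the
# bound methods with instance-level no-ops. Instance attributes take
# precedence over class methods, so later calls silently do nothing.
model.unset_lora_adapter = lambda: None
model.set_lora_adapter = lambda _: None

model.unset_lora_adapter()  # no-op now; adapters stay active
assert model.active_adapters == {"vision", "speech"}
```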