diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 5e7be62342c3..905e5435ae91 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -22,12 +22,11 @@ import re import tempfile import traceback -import unittest -import unittest.mock as mock import uuid import warnings from collections import defaultdict from typing import Dict, List, Optional, Tuple, Union +from unittest import mock import numpy as np import pytest @@ -210,16 +209,18 @@ def cast_maybe_tensor_dtype(maybe_tensor, current_dtype, target_dtype): return maybe_tensor -class ModelUtilsTest(unittest.TestCase): - def tearDown(self): - super().tearDown() +class TestModelUtils: + def teardown_method(self): + pass + + def test_missing_key_loading_warning_message(self, caplog): + import logging - def test_missing_key_loading_warning_message(self): - with self.assertLogs("diffusers.models.modeling_utils", level="WARNING") as logs: - UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet") + caplog.set_level(logging.WARNING, logger="diffusers.models.modeling_utils") + UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet") # make sure that error message states what keys are missing - assert "conv_out.bias" in " ".join(logs.output) + assert "conv_out.bias" in caplog.text @parameterized.expand( [ @@ -236,7 +237,7 @@ def load_model(path): kwargs["subfolder"] = subfolder return UNet2DConditionModel.from_pretrained(path, **kwargs) - with self.assertWarns(FutureWarning) as warning: + with pytest.warns(FutureWarning) as warning: if use_local: with tempfile.TemporaryDirectory() as tmpdirname: tmpdirname = snapshot_download(repo_id=repo_id) @@ -244,8 +245,8 @@ def load_model(path): else: _ = load_model(repo_id) - warning_message = str(warning.warnings[0].message) - self.assertIn("This serialization format is now deprecated to standardize the serialization", warning_message) + warning_message = str(warning.list[0].message) + assert "This serialization format is now deprecated to standardize the serialization" in warning_message # Local tests are already covered down below. @parameterized.expand( @@ -306,7 +307,7 @@ def test_local_files_only_with_sharded_checkpoint(self): with mock.patch("requests.Session.get", return_value=error_response): # Should fail with local_files_only=False (network required) # We would make a network call with model_info - with self.assertRaises(OSError): + with pytest.raises(OSError): FluxTransformer2DModel.from_pretrained( repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=False ) @@ -328,7 +329,7 @@ def test_local_files_only_with_sharded_checkpoint(self): os.remove(cached_shard_file) # Attempting to load from cache should raise an error - with self.assertRaises(OSError) as context: + with pytest.raises(OSError) as context: FluxTransformer2DModel.from_pretrained( repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True ) @@ -339,8 +340,8 @@ def test_local_files_only_with_sharded_checkpoint(self): f"Expected error about missing shard, got: {error_msg}" ) - @unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners") - @unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.") + @pytest.mark.skip(reason="Flaky behaviour on CI. 
Re-enable after migrating to new runners") + @pytest.mark.skipif(torch_device == "mps", reason="Test not supported for MPS.") def test_one_request_upon_cached(self): use_safetensors = False @@ -373,7 +374,7 @@ def test_one_request_upon_cached(self): ) def test_weight_overwrite(self): - with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context: + with tempfile.TemporaryDirectory() as tmpdirname, pytest.raises(ValueError) as error_context: UNet2DConditionModel.from_pretrained( "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", @@ -414,9 +415,9 @@ def test_keep_modules_in_fp32(self): for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear): if name in model._keep_in_fp32_modules: - self.assertTrue(module.weight.dtype == torch.float32) + assert module.weight.dtype == torch.float32 else: - self.assertTrue(module.weight.dtype == torch_dtype) + assert module.weight.dtype == torch_dtype def get_dummy_inputs(): batch_size = 2 @@ -466,9 +467,9 @@ def test_forward_with_norm_groups(self): if isinstance(output, dict): output = output.to_tuple()[0] - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" class ModelTesterMixin: @@ -478,6 +479,18 @@ class ModelTesterMixin: model_split_percents = [0.5, 0.7, 0.9] uses_custom_attn_processor = False + def get_init_dict(self): + raise NotImplementedError( + "You need to implement `get_init_dict(self)` in the child test class. " + "See existing pipeline tests for reference." + ) + + def get_dummy_inputs(self, device, seed=0): + raise NotImplementedError( + "You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. " + "See existing pipeline tests for reference." 
+ ) + def check_device_map_is_respected(self, model, device_map): for param_name, param in model.named_parameters(): # Find device in device_map @@ -488,9 +501,9 @@ def check_device_map_is_respected(self, model, device_map): param_device = device_map[param_name] if param_device in ["cpu", "disk"]: - self.assertEqual(param.device, torch.device("meta")) + assert param.device == torch.device("meta") else: - self.assertEqual(param.device, torch.device(param_device)) + assert param.device == torch.device(param_device) def test_from_save_pretrained(self, expected_max_diff=5e-5): if self.forward_requires_fresh_args: @@ -529,7 +542,7 @@ def test_from_save_pretrained(self, expected_max_diff=5e-5): new_image = new_image.to_tuple()[0] max_diff = (image - new_image).abs().max().item() - self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes") + assert max_diff <= expected_max_diff, "Models give different forward passes" def test_getattr_is_correct(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -563,18 +576,18 @@ def test_getattr_is_correct(self): assert cap_logger.out == "" # warning should be thrown - with self.assertWarns(FutureWarning): + with pytest.warns(FutureWarning): assert model.test_attribute == 5 - with self.assertWarns(FutureWarning): + with pytest.warns(FutureWarning): assert getattr(model, "test_attribute") == 5 - with self.assertRaises(AttributeError) as error: + with pytest.raises(AttributeError) as error: model.does_not_exist - assert str(error.exception) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'" + assert str(error.value) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'" - @unittest.skipIf( + @pytest.mark.skipif( torch_device != "npu" or not is_torch_npu_available(), reason="torch npu flash attention is only available with NPU and `torch_npu` installed", ) @@ -621,7 +634,7 @@ def test_set_torch_npu_flash_attn_processor_determinism(self): assert torch.allclose(output, output_3, atol=self.base_precision) assert torch.allclose(output_2, output_3, atol=self.base_precision) - @unittest.skipIf( + @pytest.mark.skipif( torch_device != "cuda" or not is_xformers_available(), reason="XFormers attention is only available with CUDA and `xformers` installed", ) @@ -748,7 +761,7 @@ def test_from_save_pretrained_variant(self, expected_max_diff=5e-5): new_model.set_default_attn_processor() # non-variant cannot be loaded - with self.assertRaises(OSError) as error_context: + with pytest.raises(OSError) as error_context: self.model_class.from_pretrained(tmpdirname) # make sure that error message states what keys are missing @@ -773,11 +786,11 @@ def test_from_save_pretrained_variant(self, expected_max_diff=5e-5): new_image = new_image.to_tuple()[0] max_diff = (image - new_image).abs().max().item() - self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes") + assert max_diff <= expected_max_diff, "Models give different forward passes" @is_torch_compile @require_torch_2 - @unittest.skipIf( + @pytest.mark.skipif( get_python_version == (3, 12), reason="Torch Dynamo isn't yet supported for Python 3.12.", ) @@ -839,7 +852,7 @@ def test_determinism(self, expected_max_diff=1e-5): out_1 = out_1[~np.isnan(out_1)] out_2 = out_2[~np.isnan(out_2)] max_diff = np.amax(np.abs(out_1 - out_2)) - self.assertLessEqual(max_diff, expected_max_diff) + assert max_diff <= expected_max_diff def test_output(self, expected_output_shape=None): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -853,16 +866,16 @@ def test_output(self, 
expected_output_shape=None): if isinstance(output, dict): output = output.to_tuple()[0] - self.assertIsNotNone(output) + assert output is not None # input & output have to have the same shape input_tensor = inputs_dict[self.main_input_name] if expected_output_shape is None: expected_shape = input_tensor.shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" else: - self.assertEqual(output.shape, expected_output_shape, "Input and output shapes do not match") + assert output.shape == expected_output_shape, "Input and output shapes do not match" def test_model_from_pretrained(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -883,7 +896,7 @@ def test_model_from_pretrained(self): for param_name in model.state_dict().keys(): param_1 = model.state_dict()[param_name] param_2 = new_model.state_dict()[param_name] - self.assertEqual(param_1.shape, param_2.shape) + assert param_1.shape == param_2.shape with torch.no_grad(): output_1 = model(**inputs_dict) @@ -896,7 +909,7 @@ def test_model_from_pretrained(self): if isinstance(output_2, dict): output_2 = output_2.to_tuple()[0] - self.assertEqual(output_1.shape, output_2.shape) + assert output_1.shape == output_2.shape @require_torch_accelerator_with_training def test_training(self): @@ -955,16 +968,13 @@ def recursive_check(tuple_object, dict_object): elif tuple_object is None: return else: - self.assertTrue( - torch.allclose( - set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5 - ), - msg=( - "Tuple and dict output are not equal. Difference:" - f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has" - f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}." - ), + assert torch.allclose( + set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5 + ), ( + "Tuple and dict output are not equal. Difference:" + f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has" + f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}." 
) if self.forward_requires_fresh_args: @@ -996,15 +1006,15 @@ def test_enable_disable_gradient_checkpointing(self): # at init model should have gradient checkpointing disabled model = self.model_class(**init_dict) - self.assertFalse(model.is_gradient_checkpointing) + assert not model.is_gradient_checkpointing # check enable works model.enable_gradient_checkpointing() - self.assertTrue(model.is_gradient_checkpointing) + assert model.is_gradient_checkpointing # check disable works model.disable_gradient_checkpointing() - self.assertFalse(model.is_gradient_checkpointing) + assert not model.is_gradient_checkpointing @require_torch_accelerator_with_training def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip: set[str] = {}): @@ -1048,7 +1058,7 @@ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_ loss_2.backward() # compare the output and parameters gradients - self.assertTrue((loss - loss_2).abs() < loss_tolerance) + assert (loss - loss_2).abs() < loss_tolerance named_params = dict(model.named_parameters()) named_params_2 = dict(model_2.named_parameters()) @@ -1061,9 +1071,9 @@ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_ # It currently errors out the gradient checkpointing test because the gradients for attn2.to_out is None if param.grad is None: continue - self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol)) + assert torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol) - @unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.") + @pytest.mark.skipif(torch_device == "mps", reason="This test is not supported for MPS devices.") def test_gradient_checkpointing_is_applied( self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None ): @@ -1087,7 +1097,7 @@ def test_gradient_checkpointing_is_applied( modules_with_gc_enabled = {} for submodule in model.modules(): if hasattr(submodule, "gradient_checkpointing"): - self.assertTrue(submodule.gradient_checkpointing) + assert submodule.gradient_checkpointing modules_with_gc_enabled[submodule.__class__.__name__] = True assert set(modules_with_gc_enabled.keys()) == expected_set @@ -1115,7 +1125,7 @@ def test_deprecated_kwargs(self): @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)]) @torch.no_grad() - @unittest.skipIf(not is_peft_available(), "Only with PEFT") + @pytest.mark.skipif(not is_peft_available(), reason="Only with PEFT") def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False): from peft import LoraConfig from peft.utils import get_peft_model_state_dict @@ -1139,21 +1149,21 @@ def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False): use_dora=use_dora, ) model.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly") + assert check_if_lora_correctly_set(model), "LoRA layers not set correctly" torch.manual_seed(0) outputs_with_lora = model(**inputs_dict, return_dict=False)[0] - self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4)) + assert not torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4) with tempfile.TemporaryDirectory() as tmpdir: model.save_lora_adapter(tmpdir) - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + assert os.path.isfile(os.path.join(tmpdir, 
"pytorch_lora_weights.safetensors")) state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) model.unload_lora() - self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly") + assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly" model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True) state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0") @@ -1161,17 +1171,17 @@ def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False): for k in state_dict_loaded: loaded_v = state_dict_loaded[k] retrieved_v = state_dict_retrieved[k].to(loaded_v.device) - self.assertTrue(torch.allclose(loaded_v, retrieved_v)) + assert torch.allclose(loaded_v, retrieved_v) - self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly") + assert check_if_lora_correctly_set(model), "LoRA layers not set correctly" torch.manual_seed(0) outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0] - self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)) - self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)) + assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4) + assert torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4) - @unittest.skipIf(not is_peft_available(), "Only with PEFT") + @pytest.mark.skipif(not is_peft_available(), reason="Only with PEFT") def test_lora_wrong_adapter_name_raises_error(self): from peft import LoraConfig @@ -1191,18 +1201,18 @@ def test_lora_wrong_adapter_name_raises_error(self): use_dora=False, ) model.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly") + assert check_if_lora_correctly_set(model), "LoRA layers not set correctly" with tempfile.TemporaryDirectory() as tmpdir: wrong_name = "foo" - with self.assertRaises(ValueError) as err_context: + with pytest.raises(ValueError) as err_context: model.save_lora_adapter(tmpdir, adapter_name=wrong_name) - self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception)) + assert f"Adapter name {wrong_name} not found in the model." 
in str(err_context.value) @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)]) @torch.no_grad() - @unittest.skipIf(not is_peft_available(), "Only with PEFT") + @pytest.mark.skipif(not is_peft_available(), reason="Only with PEFT") def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora): from peft import LoraConfig @@ -1223,22 +1233,22 @@ def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_d ) model.add_adapter(denoiser_lora_config) metadata = model.peft_config["default"].to_dict() - self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly") + assert check_if_lora_correctly_set(model), "LoRA layers not set correctly" with tempfile.TemporaryDirectory() as tmpdir: model.save_lora_adapter(tmpdir) model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") - self.assertTrue(os.path.isfile(model_file)) + assert os.path.isfile(model_file) model.unload_lora() - self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly") + assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly" model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True) parsed_metadata = model.peft_config["default_0"].to_dict() check_if_dicts_are_equal(metadata, parsed_metadata) @torch.no_grad() - @unittest.skipIf(not is_peft_available(), "Only with PEFT") + @pytest.mark.skipif(not is_peft_available(), reason="Only with PEFT") def test_lora_adapter_wrong_metadata_raises_error(self): from peft import LoraConfig @@ -1259,12 +1269,12 @@ def test_lora_adapter_wrong_metadata_raises_error(self): use_dora=False, ) model.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly") + assert check_if_lora_correctly_set(model), "LoRA layers not set correctly" with tempfile.TemporaryDirectory() as tmpdir: model.save_lora_adapter(tmpdir) model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") - self.assertTrue(os.path.isfile(model_file)) + assert os.path.isfile(model_file) # Perturb the metadata in the state dict. 
loaded_state_dict = safetensors.torch.load_file(model_file) @@ -1278,11 +1288,11 @@ def test_lora_adapter_wrong_metadata_raises_error(self): safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata) model.unload_lora() - self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly") + assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly" - with self.assertRaises(TypeError) as err_context: + with pytest.raises(TypeError) as err_context: model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True) - self.assertTrue("`LoraConfig` class could not be instantiated" in str(err_context.exception)) + assert "`LoraConfig` class could not be instantiated" in str(err_context.value) @require_torch_accelerator def test_cpu_offload(self): @@ -1306,13 +1316,13 @@ def test_cpu_offload(self): max_memory = {0: max_size, "cpu": model_size * 2} new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) # Making sure part of the model will actually end up offloaded - self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"}) + assert set(new_model.hf_device_map.values()) == {0, "cpu"} self.check_device_map_is_respected(new_model, new_model.hf_device_map) torch.manual_seed(0) new_output = new_model(**inputs_dict) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + assert torch.allclose(base_output[0], new_output[0], atol=1e-5) @require_torch_accelerator def test_disk_offload_without_safetensors(self): @@ -1333,7 +1343,7 @@ def test_disk_offload_without_safetensors(self): with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir, safe_serialization=False) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # This errors out because it's missing an offload folder new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) @@ -1345,7 +1355,7 @@ def test_disk_offload_without_safetensors(self): torch.manual_seed(0) new_output = new_model(**inputs_dict) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + assert torch.allclose(base_output[0], new_output[0], atol=1e-5) @require_torch_accelerator def test_disk_offload_with_safetensors(self): @@ -1373,7 +1383,7 @@ def test_disk_offload_with_safetensors(self): torch.manual_seed(0) new_output = new_model(**inputs_dict) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + assert torch.allclose(base_output[0], new_output[0], atol=1e-5) @require_torch_multi_accelerator def test_model_parallelism(self): @@ -1397,14 +1407,14 @@ def test_model_parallelism(self): max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2} new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) # Making sure part of the model will actually end up offloaded - self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1}) + assert set(new_model.hf_device_map.values()) == {0, 1} self.check_device_map_is_respected(new_model, new_model.hf_device_map) torch.manual_seed(0) new_output = new_model(**inputs_dict) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + assert torch.allclose(base_output[0], new_output[0], atol=1e-5) @require_torch_accelerator def test_sharded_checkpoints(self): @@ -1419,14 +1429,14 @@ def test_sharded_checkpoints(self): max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small. 
with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB") - self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))) + assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) # Now check if the right number of shards exists. First, let's get the number of shards. # Since this number can be dependent on the model being tested, it's important that we calculate it # instead of hardcoding it. expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) - self.assertTrue(actual_num_shards == expected_num_shards) + assert actual_num_shards == expected_num_shards new_model = self.model_class.from_pretrained(tmp_dir).eval() new_model = new_model.to(torch_device) @@ -1436,7 +1446,7 @@ def test_sharded_checkpoints(self): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() new_output = new_model(**inputs_dict) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + assert torch.allclose(base_output[0], new_output[0], atol=1e-5) @require_torch_accelerator def test_sharded_checkpoints_with_variant(self): @@ -1457,14 +1467,14 @@ def test_sharded_checkpoints_with_variant(self): model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant) index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) - self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_filename))) + assert os.path.exists(os.path.join(tmp_dir, index_filename)) # Now check if the right number of shards exists. First, let's get the number of shards. # Since this number can be dependent on the model being tested, it's important that we calculate it # instead of hardcoding it. expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, index_filename)) actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) - self.assertTrue(actual_num_shards == expected_num_shards) + assert actual_num_shards == expected_num_shards new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval() new_model = new_model.to(torch_device) @@ -1474,7 +1484,7 @@ def test_sharded_checkpoints_with_variant(self): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() new_output = new_model(**inputs_dict) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + assert torch.allclose(base_output[0], new_output[0], atol=1e-5) @require_torch_accelerator def test_sharded_checkpoints_with_parallel_loading(self): @@ -1489,14 +1499,14 @@ def test_sharded_checkpoints_with_parallel_loading(self): max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small. with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB") - self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))) + assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) # Now check if the right number of shards exists. First, let's get the number of shards. # Since this number can be dependent on the model being tested, it's important that we calculate it # instead of hardcoding it. 
expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) - self.assertTrue(actual_num_shards == expected_num_shards) + assert actual_num_shards == expected_num_shards # Load with parallel loading os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes" @@ -1507,7 +1517,7 @@ def test_sharded_checkpoints_with_parallel_loading(self): if "generator" in inputs_dict: _, inputs_dict = self.prepare_init_args_and_inputs_for_common() new_output = new_model(**inputs_dict) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + assert torch.allclose(base_output[0], new_output[0], atol=1e-5) # set to no. os.environ["HF_ENABLE_PARALLEL_LOADING"] = "no" @@ -1526,14 +1536,14 @@ def test_sharded_checkpoints_device_map(self): max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small. with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB") - self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))) + assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) # Now check if the right number of shards exists. First, let's get the number of shards. # Since this number can be dependent on the model being tested, it's important that we calculate it # instead of hardcoding it. expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) - self.assertTrue(actual_num_shards == expected_num_shards) + assert actual_num_shards == expected_num_shards new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto") @@ -1541,7 +1551,7 @@ def test_sharded_checkpoints_device_map(self): if "generator" in inputs_dict: _, inputs_dict = self.prepare_init_args_and_inputs_for_common() new_output = new_model(**inputs_dict) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + assert torch.allclose(base_output[0], new_output[0], atol=1e-5) # This test is okay without a GPU because we're not running any execution. We're just serializing # and check if the resultant files are following an expected format. @@ -1559,14 +1569,14 @@ def test_variant_sharded_ckpt_right_format(self): tmp_dir, variant=variant, max_shard_size=f"{max_shard_size}KB", safe_serialization=use_safe ) index_variant = _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safe else WEIGHTS_INDEX_NAME, variant) - self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_variant))) + assert os.path.exists(os.path.join(tmp_dir, index_variant)) # Now check if the right number of shards exists. First, let's get the number of shards. # Since this number can be dependent on the model being tested, it's important that we calculate it # instead of hardcoding it. expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, index_variant)) actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(extension)]) - self.assertTrue(actual_num_shards == expected_num_shards) + assert actual_num_shards == expected_num_shards # Check if the variant is present as a substring in the checkpoints. 
shard_files = [ @@ -1634,9 +1644,9 @@ def check_linear_dtype(module, storage_dtype, compute_dtype): if any(re.search(pattern, name) for pattern in patterns_to_check): dtype_to_check = compute_dtype if getattr(submodule, "weight", None) is not None: - self.assertEqual(submodule.weight.dtype, dtype_to_check) + assert submodule.weight.dtype == dtype_to_check if getattr(submodule, "bias", None) is not None: - self.assertEqual(submodule.bias.dtype, dtype_to_check) + assert submodule.bias.dtype == dtype_to_check def test_layerwise_casting(storage_dtype, compute_dtype): torch.manual_seed(0) @@ -1651,7 +1661,7 @@ def test_layerwise_casting(storage_dtype, compute_dtype): # The precision test is not very important for fast tests. In most cases, the outputs will not be the same. # We just want to make sure that the layerwise casting is working as expected. - self.assertTrue(numpy_cosine_similarity_distance(base_slice, output) < 1.0) + assert numpy_cosine_similarity_distance(base_slice, output) < 1.0 test_layerwise_casting(torch.float16, torch.float32) test_layerwise_casting(torch.float8_e4m3fn, torch.float32) @@ -1692,15 +1702,15 @@ def get_memory_usage(storage_dtype, compute_dtype): ) compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None - self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint) + assert fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32. # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it. if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY: - self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory) + assert fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few # bytes. This only happens for some models, so we allow a small tolerance. # For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32. 
- self.assertTrue( + assert ( fp8_e4m3_fp32_max_memory < fp32_max_memory or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE ) @@ -1716,12 +1726,10 @@ def test_group_offloading(self, record_stream): @torch.no_grad() def run_forward(model): - self.assertTrue( - all( - module._diffusers_hook.get_hook("group_offloading") is not None - for module in model.modules() - if hasattr(module, "_diffusers_hook") - ) + assert all( + module._diffusers_hook.get_hook("group_offloading") is not None + for module in model.modules() + if hasattr(module, "_diffusers_hook") ) model.eval() return model(**inputs_dict)[0] @@ -1753,10 +1761,10 @@ def run_forward(model): ) output_with_group_offloading4 = run_forward(model) - self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)) - self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5)) - self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5)) - self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5)) + assert torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5) + assert torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5) + assert torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5) + assert torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5) @parameterized.expand([(False, "block_level"), (True, "leaf_level")]) @require_torch_accelerator @@ -1838,7 +1846,7 @@ def _run_forward(model, inputs_dict): **additional_kwargs, ) has_safetensors = glob.glob(f"{tmpdir}/*.safetensors") - self.assertTrue(has_safetensors, "No safetensors found in the directory.") + assert has_safetensors, "No safetensors found in the directory." # For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic # in nature. So, skip it. 
@@ -1856,7 +1864,7 @@ def _run_forward(model, inputs_dict): raise ValueError(f"Following files are missing: {', '.join(missing_files)}") output_with_group_offloading = _run_forward(model, inputs_dict) - self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol)) + assert torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol) def test_auto_model(self, expected_max_diff=5e-5): if self.forward_requires_fresh_args: @@ -1895,10 +1903,8 @@ def test_auto_model(self, expected_max_diff=5e-5): output_auto = output_auto.to_tuple()[0] max_diff = (output_original - output_auto).abs().max().item() - self.assertLessEqual( - max_diff, - expected_max_diff, - f"AutoModel forward pass diff: {max_diff} exceeds threshold {expected_max_diff}", + assert max_diff <= expected_max_diff, ( + f"AutoModel forward pass diff: {max_diff} exceeds threshold {expected_max_diff}" ) @parameterized.expand( [ @@ -1912,7 +1918,7 @@ def test_wrong_device_map_raises_error(self, device_map, msg_substring): model = self.model_class(**init_dict) with tempfile.TemporaryDirectory() as tmpdir: model.save_pretrained(tmpdir) - with self.assertRaises(ValueError) as err_ctx: + with pytest.raises(ValueError) as err_ctx: _ = self.model_class.from_pretrained(tmpdir, device_map=device_map) - assert msg_substring in str(err_ctx.exception) + assert msg_substring in str(err_ctx.value) @@ -1942,7 +1948,7 @@ def test_passing_dict_device_map_works(self, name, device): @is_staging_test -class ModelPushToHubTester(unittest.TestCase): +class TestModelPushToHub: identifier = uuid.uuid4() repo_id = f"test-model-{identifier}" org_repo_id = f"valid_org/{repo_id}-org" @@ -1962,7 +1968,7 @@ def test_push_to_hub(self): new_model = UNet2DConditionModel.from_pretrained(f"{USER}/{self.repo_id}") for p1, p2 in zip(model.parameters(), new_model.parameters()): - self.assertTrue(torch.equal(p1, p2)) + assert torch.equal(p1, p2) # Reset repo delete_repo(token=TOKEN, repo_id=self.repo_id) @@ -1973,7 +1979,7 @@ def test_push_to_hub(self): new_model = UNet2DConditionModel.from_pretrained(f"{USER}/{self.repo_id}") for p1, p2 in zip(model.parameters(), new_model.parameters()): - self.assertTrue(torch.equal(p1, p2)) + assert torch.equal(p1, p2) # Reset repo delete_repo(self.repo_id, token=TOKEN) @@ -1993,7 +1999,7 @@ def test_push_to_hub_in_organization(self): new_model = UNet2DConditionModel.from_pretrained(self.org_repo_id) for p1, p2 in zip(model.parameters(), new_model.parameters()): - self.assertTrue(torch.equal(p1, p2)) + assert torch.equal(p1, p2) # Reset repo delete_repo(token=TOKEN, repo_id=self.org_repo_id) @@ -2004,12 +2010,12 @@ def test_push_to_hub_in_organization(self): new_model = UNet2DConditionModel.from_pretrained(self.org_repo_id) for p1, p2 in zip(model.parameters(), new_model.parameters()): - self.assertTrue(torch.equal(p1, p2)) + assert torch.equal(p1, p2) # Reset repo delete_repo(self.org_repo_id, token=TOKEN) - @unittest.skipIf( + @pytest.mark.skipif( not is_jinja_available(), reason="Model card tests cannot be performed without Jinja installed.", ) @@ -2296,7 +2302,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_ # check error when not passing valid adapter name name = "does-not-exist" msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name" - with self.assertRaisesRegex(ValueError, msg): + with pytest.raises(ValueError, match=msg): model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None) @parameterized.expand([(11, 11), 
(7, 13), (13, 7)]) # important to test small to large and vice versa @@ -2357,10 +2363,10 @@ def test_enable_lora_hotswap_called_after_adapter_added_raises(self): model.add_adapter(lora_config) msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.") - with self.assertRaisesRegex(RuntimeError, msg): + with pytest.raises(RuntimeError, match=msg): model.enable_lora_hotswap(target_rank=32) - def test_enable_lora_hotswap_called_after_adapter_added_warning(self): + def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog): # ensure that enable_lora_hotswap is called before loading the first adapter from diffusers.loaders.peft import logger @@ -2371,9 +2377,11 @@ def test_enable_lora_hotswap_called_after_adapter_added_warning(self): msg = ( "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation." ) - with self.assertLogs(logger=logger, level="WARNING") as cm: + import logging + + with caplog.at_level(logging.WARNING, logger=logger.name): model.enable_lora_hotswap(target_rank=32, check_compiled="warn") - assert any(msg in log for log in cm.output) + assert any(msg in record.message for record in caplog.records) def test_enable_lora_hotswap_called_after_adapter_added_ignore(self): # check possibility to ignore the error/warning @@ -2384,7 +2392,7 @@ def test_enable_lora_hotswap_called_after_adapter_added_ignore(self): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") # Capture all warnings model.enable_lora_hotswap(target_rank=32, check_compiled="warn") - self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}") + assert len(w) == 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}" def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self): # check that wrong argument value raises an error @@ -2393,22 +2401,24 @@ def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self): model = self.model_class(**init_dict).to(torch_device) model.add_adapter(lora_config) msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.") - with self.assertRaisesRegex(ValueError, msg): + with pytest.raises(ValueError, match=msg): model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument") - def test_hotswap_second_adapter_targets_more_layers_raises(self): + def test_hotswap_second_adapter_targets_more_layers_raises(self, caplog): # check the error and log from diffusers.loaders.peft import logger # at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers target_modules0 = ["to_q"] target_modules1 = ["to_q", "to_k"] - with self.assertRaises(RuntimeError): # peft raises RuntimeError - with self.assertLogs(logger=logger, level="ERROR") as cm: + with pytest.raises(RuntimeError): # peft raises RuntimeError + import logging + + with caplog.at_level(logging.ERROR, logger=logger.name): self.check_model_hotswap( do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1 ) - assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output) + assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records) @parameterized.expand([(11, 11), (7, 13), (13, 7)]) @require_torch_version_greater("2.7.1")