Skip to content

Commit 3e5cc12

Browse files
authored
[tests] remove tests from libraries with deprecated support (flax, tensorflow_text, ...) (#39051)
* rm tf/flax tests * more flax deletions * revert fixture change * reverted test that should not be deleted; rm tf/flax test * revert * fix a few add-model-like tests * fix add-model-like checkpoint source * a few more * test_get_model_files_only_pt fix * fix test_retrieve_info_for_model_with_xxx * fix test_retrieve_model_classes * relative paths are the devil * add todo
1 parent cfff7ca commit 3e5cc12

16 files changed

+156
-691
lines changed

src/transformers/commands/add_new_model_like.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ def get_model_files(model_type: str, frameworks: Optional[list[str]] = None) ->
659659
return {"doc_file": doc_file, "model_files": model_files, "module_name": module_name, "test_files": test_files}
660660

661661

662-
_re_checkpoint_for_doc = re.compile(r"^_CHECKPOINT_FOR_DOC\s+=\s+(\S*)\s*$", flags=re.MULTILINE)
662+
_re_checkpoint_in_config = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")
663663

664664

665665
def find_base_model_checkpoint(
@@ -680,13 +680,14 @@ def find_base_model_checkpoint(
680680
model_files = get_model_files(model_type)
681681
module_files = model_files["model_files"]
682682
for fname in module_files:
683-
if "modeling" not in str(fname):
683+
# After the @auto_docstring refactor, we expect the checkpoint to be in the configuration file's docstring
684+
if "configuration" not in str(fname):
684685
continue
685686

686687
with open(fname, "r", encoding="utf-8") as f:
687688
content = f.read()
688-
if _re_checkpoint_for_doc.search(content) is not None:
689-
checkpoint = _re_checkpoint_for_doc.search(content).groups()[0]
689+
if _re_checkpoint_in_config.search(content) is not None:
690+
checkpoint = _re_checkpoint_in_config.search(content).groups()[0]
690691
# Remove quotes
691692
checkpoint = checkpoint.replace('"', "")
692693
checkpoint = checkpoint.replace("'", "")

src/transformers/testing_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,10 @@ def require_jinja(test_case):
495495

496496

497497
def require_tf2onnx(test_case):
498+
logger.warning_once(
499+
"TensorFlow test-related code, including `require_tf2onnx`, is deprecated and will be removed in "
500+
"Transformers v4.55"
501+
)
498502
return unittest.skipUnless(is_tf2onnx_available(), "test requires tf2onnx")(test_case)
499503

500504

@@ -689,6 +693,10 @@ def require_tensorflow_probability(test_case):
689693
These tests are skipped when TensorFlow probability isn't installed.
690694
691695
"""
696+
logger.warning_once(
697+
"TensorFlow test-related code, including `require_tensorflow_probability`, is deprecated and will be "
698+
"removed in Transformers v4.55"
699+
)
692700
return unittest.skipUnless(is_tensorflow_probability_available(), "test requires TensorFlow probability")(
693701
test_case
694702
)
@@ -715,6 +723,9 @@ def require_flax(test_case):
715723
"""
716724
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
717725
"""
726+
logger.warning_once(
727+
"JAX test-related code, including `require_flax`, is deprecated and will be removed in Transformers v4.55"
728+
)
718729
return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
719730

720731

@@ -758,6 +769,10 @@ def require_tensorflow_text(test_case):
758769
Decorator marking a test that requires tensorflow_text. These tests are skipped when tensroflow_text isn't
759770
installed.
760771
"""
772+
logger.warning_once(
773+
"TensorFlow test-related code, including `require_tensorflow_text`, is deprecated and will be "
774+
"removed in Transformers v4.55"
775+
)
761776
return unittest.skipUnless(is_tensorflow_text_available(), "test requires tensorflow_text")(test_case)
762777

763778

tests/fixtures/add_distilbert_like_config.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
"tf",
1717
"flax"
1818
]
19-
}
19+
}

tests/models/tapas/test_tokenization_tapas.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
)
3434
from transformers.testing_utils import (
3535
require_pandas,
36-
require_tensorflow_probability,
3736
require_tokenizers,
3837
require_torch,
3938
slow,
@@ -140,41 +139,6 @@ def get_input_output_texts(self, tokenizer):
140139
output_text = "unwanted, running"
141140
return input_text, output_text
142141

143-
@require_tensorflow_probability
144-
@slow
145-
def test_tf_encode_plus_sent_to_model(self):
146-
from transformers import TF_MODEL_MAPPING, TOKENIZER_MAPPING
147-
148-
MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(TF_MODEL_MAPPING, TOKENIZER_MAPPING)
149-
150-
tokenizers = self.get_tokenizers(do_lower_case=False)
151-
for tokenizer in tokenizers:
152-
with self.subTest(f"{tokenizer.__class__.__name__}"):
153-
if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
154-
self.skipTest(f"{tokenizer.__class__} is not in the MODEL_TOKENIZER_MAPPING")
155-
156-
config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
157-
config = config_class()
158-
159-
if config.is_encoder_decoder or config.pad_token_id is None:
160-
self.skipTest(reason="Model is an encoder-decoder or does not have a pad token id set")
161-
162-
model = model_class(config)
163-
164-
# Make sure the model contains at least the full vocabulary size in its embedding matrix
165-
self.assertGreaterEqual(model.config.vocab_size, len(tokenizer))
166-
167-
# Build sequence
168-
first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
169-
sequence = " ".join(first_ten_tokens)
170-
table = self.get_table(tokenizer, length=0)
171-
encoded_sequence = tokenizer.encode_plus(table, sequence, return_tensors="tf")
172-
batch_encoded_sequence = tokenizer.batch_encode_plus(table, [sequence, sequence], return_tensors="tf")
173-
174-
# This should not fail
175-
model(encoded_sequence)
176-
model(batch_encoded_sequence)
177-
178142
def test_rust_and_python_full_tokenizers(self):
179143
if not self.test_rust_tokenizer:
180144
self.skipTest(reason="test_rust_tokenizer is set to False")

tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,6 @@ def check_vision_text_output_attention(
161161
(text_config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]),
162162
)
163163

164-
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
165-
diff = np.abs(a - b).max()
166-
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
167-
168164
def test_vision_text_dual_encoder_model(self):
169165
inputs_dict = self.prepare_config_and_inputs()
170166
self.check_vision_text_dual_encoder_model(**inputs_dict)

tests/models/wav2vec2/test_modeling_wav2vec2.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -813,12 +813,6 @@ def flatten_output(output):
813813
# (Even with this call, there are still memory leak by ~0.04MB)
814814
self.clear_torch_jit_class_registry()
815815

816-
@unittest.skip(
817-
"Need to investigate why config.do_stable_layer_norm is set to False here when it doesn't seem to be supported"
818-
)
819-
def test_flax_from_pt_safetensors(self):
820-
return
821-
822816

823817
@require_torch
824818
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):

tests/models/whisper/test_tokenization_whisper.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
2020
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
21-
from transformers.testing_utils import require_flax, require_torch, slow
21+
from transformers.testing_utils import require_torch, slow
2222

2323
from ...test_tokenization_common import TokenizerTesterMixin
2424

@@ -588,15 +588,6 @@ def test_convert_to_list_np(self):
588588
self.assertListEqual(WhisperTokenizer._convert_to_list(np_array), test_list)
589589
self.assertListEqual(WhisperTokenizerFast._convert_to_list(np_array), test_list)
590590

591-
@require_flax
592-
def test_convert_to_list_jax(self):
593-
import jax.numpy as jnp
594-
595-
test_list = [[1, 2, 3], [4, 5, 6]]
596-
jax_array = jnp.array(test_list)
597-
self.assertListEqual(WhisperTokenizer._convert_to_list(jax_array), test_list)
598-
self.assertListEqual(WhisperTokenizerFast._convert_to_list(jax_array), test_list)
599-
600591
@require_torch
601592
def test_convert_to_list_pt(self):
602593
import torch

tests/pipelines/test_pipelines_table_question_answering.py

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,10 @@
1919
AutoModelForTableQuestionAnswering,
2020
AutoTokenizer,
2121
TableQuestionAnsweringPipeline,
22-
TFAutoModelForTableQuestionAnswering,
2322
pipeline,
2423
)
2524
from transformers.testing_utils import (
2625
is_pipeline_test,
27-
require_pandas,
28-
require_tensorflow_probability,
2926
require_torch,
3027
slow,
3128
)
@@ -316,55 +313,6 @@ def test_integration_wtq_pt(self, torch_dtype="float32"):
316313
def test_integration_wtq_pt_fp16(self):
317314
self.test_integration_wtq_pt(torch_dtype="float16")
318315

319-
@slow
320-
@require_tensorflow_probability
321-
@require_pandas
322-
def test_integration_wtq_tf(self):
323-
model_id = "google/tapas-base-finetuned-wtq"
324-
model = TFAutoModelForTableQuestionAnswering.from_pretrained(model_id)
325-
tokenizer = AutoTokenizer.from_pretrained(model_id)
326-
table_querier = pipeline("table-question-answering", model=model, tokenizer=tokenizer)
327-
328-
data = {
329-
"Repository": ["Transformers", "Datasets", "Tokenizers"],
330-
"Stars": ["36542", "4512", "3934"],
331-
"Contributors": ["651", "77", "34"],
332-
"Programming language": ["Python", "Python", "Rust, Python and NodeJS"],
333-
}
334-
queries = [
335-
"What repository has the largest number of stars?",
336-
"Given that the numbers of stars defines if a repository is active, what repository is the most active?",
337-
"What is the number of repositories?",
338-
"What is the average number of stars?",
339-
"What is the total amount of stars?",
340-
]
341-
342-
results = table_querier(data, queries)
343-
344-
expected_results = [
345-
{"answer": "Transformers", "coordinates": [(0, 0)], "cells": ["Transformers"], "aggregator": "NONE"},
346-
{"answer": "Transformers", "coordinates": [(0, 0)], "cells": ["Transformers"], "aggregator": "NONE"},
347-
{
348-
"answer": "COUNT > Transformers, Datasets, Tokenizers",
349-
"coordinates": [(0, 0), (1, 0), (2, 0)],
350-
"cells": ["Transformers", "Datasets", "Tokenizers"],
351-
"aggregator": "COUNT",
352-
},
353-
{
354-
"answer": "AVERAGE > 36542, 4512, 3934",
355-
"coordinates": [(0, 1), (1, 1), (2, 1)],
356-
"cells": ["36542", "4512", "3934"],
357-
"aggregator": "AVERAGE",
358-
},
359-
{
360-
"answer": "SUM > 36542, 4512, 3934",
361-
"coordinates": [(0, 1), (1, 1), (2, 1)],
362-
"cells": ["36542", "4512", "3934"],
363-
"aggregator": "SUM",
364-
},
365-
]
366-
self.assertListEqual(results, expected_results)
367-
368316
@slow
369317
@require_torch
370318
def test_integration_sqa_pt(self, torch_dtype="float32"):
@@ -395,34 +343,6 @@ def test_integration_sqa_pt(self, torch_dtype="float32"):
395343
def test_integration_sqa_pt_fp16(self):
396344
self.test_integration_sqa_pt(torch_dtype="float16")
397345

398-
@slow
399-
@require_tensorflow_probability
400-
@require_pandas
401-
def test_integration_sqa_tf(self):
402-
model_id = "google/tapas-base-finetuned-sqa"
403-
model = TFAutoModelForTableQuestionAnswering.from_pretrained(model_id)
404-
tokenizer = AutoTokenizer.from_pretrained(model_id)
405-
table_querier = pipeline(
406-
"table-question-answering",
407-
model=model,
408-
tokenizer=tokenizer,
409-
)
410-
data = {
411-
"Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
412-
"Age": ["56", "45", "59"],
413-
"Number of movies": ["87", "53", "69"],
414-
"Date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"],
415-
}
416-
queries = ["How many movies has George Clooney played in?", "How old is he?", "What's his date of birth?"]
417-
results = table_querier(data, queries, sequential=True)
418-
419-
expected_results = [
420-
{"answer": "69", "coordinates": [(2, 2)], "cells": ["69"]},
421-
{"answer": "59", "coordinates": [(2, 1)], "cells": ["59"]},
422-
{"answer": "28 november 1967", "coordinates": [(2, 3)], "cells": ["28 november 1967"]},
423-
]
424-
self.assertListEqual(results, expected_results)
425-
426346
@slow
427347
@require_torch
428348
def test_large_model_pt_tapex(self, torch_dtype="float32"):

tests/test_image_transforms.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,13 @@
1717
import numpy as np
1818
from parameterized import parameterized
1919

20-
from transformers.testing_utils import require_flax, require_torch, require_vision
21-
from transformers.utils.import_utils import is_flax_available, is_torch_available, is_vision_available
20+
from transformers.testing_utils import require_torch, require_vision
21+
from transformers.utils.import_utils import is_torch_available, is_vision_available
2222

2323

2424
if is_torch_available():
2525
import torch
2626

27-
if is_flax_available():
28-
import jax
29-
3027
if is_vision_available():
3128
import PIL.Image
3229

@@ -133,21 +130,6 @@ def test_to_pil_image_from_torch(self):
133130
self.assertIsInstance(pil_image, PIL.Image.Image)
134131
self.assertEqual(pil_image.size, (5, 4))
135132

136-
@require_flax
137-
def test_to_pil_image_from_jax(self):
138-
key = jax.random.PRNGKey(0)
139-
# channel first
140-
image = jax.random.uniform(key, (3, 4, 5))
141-
pil_image = to_pil_image(image)
142-
self.assertIsInstance(pil_image, PIL.Image.Image)
143-
self.assertEqual(pil_image.size, (5, 4))
144-
145-
# channel last
146-
image = jax.random.uniform(key, (4, 5, 3))
147-
pil_image = to_pil_image(image)
148-
self.assertIsInstance(pil_image, PIL.Image.Image)
149-
self.assertEqual(pil_image.size, (5, 4))
150-
151133
def test_to_channel_dimension_format(self):
152134
# Test that function doesn't reorder if channel dim matches the input.
153135
image = np.random.rand(3, 4, 5)

tests/test_modeling_common.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2453,10 +2453,6 @@ def _postprocessing_to_ignore_test_cases(self, tf_outputs, pt_outputs, model_cla
24532453

24542454
return new_tf_outputs, new_pt_outputs
24552455

2456-
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
2457-
diff = np.abs(a - b).max()
2458-
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
2459-
24602456
def test_inputs_embeds(self):
24612457
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
24622458

0 commit comments

Comments
 (0)