Commit dcefb13

BREAKING: When using `get_file` with `extract=True` or `untar=True`, the return value will be the path of the extracted directory rather than the path of the archive.
1 parent 829c9aa commit dcefb13
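
For illustration, a minimal sketch of the behavior change (the URL below is hypothetical; the archive naming assumes an `fname` without an extension, per the new logic in this commit):

```python
from keras.utils import get_file

# Hypothetical archive; any tar.gz or zip origin behaves the same way.
path = get_file(
    fname="flowers",
    origin="https://example.com/flowers.tar.gz",
    extract=True,
)

# Before this commit: `path` pointed at the downloaded archive itself.
# After this commit:  `path` is the directory the archive was extracted into
# (e.g. ~/.keras/datasets/flowers), while the downloaded archive is kept
# alongside it as "flowers_archive".
print(path)
```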

File tree

2 files changed (+66, -51 lines)

keras/src/utils/file_utils.py

Lines changed: 52 additions & 34 deletions
@@ -163,14 +163,18 @@ def get_file(
     ```
 
     Args:
-        fname: Name of the file. If an absolute path, e.g. `"/path/to/file.txt"`
-            is specified, the file will be saved at that location.
+        fname: If the target is a single file, this is your desired
+            local name for the file.
             If `None`, the name of the file at `origin` will be used.
+            If downloading and extracting a directory archive,
+            the provided `fname` will be used as extraction directory
+            name (only if it doesn't have an extension).
         origin: Original URL of the file.
         untar: Deprecated in favor of `extract` argument.
-            boolean, whether the file should be decompressed
+            Boolean, whether the file is a tar archive that should
+            be extracted.
         md5_hash: Deprecated in favor of `file_hash` argument.
-            md5 hash of the file for verification
+            md5 hash of the file for file integrity verification.
         file_hash: The expected hash string of the file after download.
             The sha256 and md5 hash algorithms are both supported.
         cache_subdir: Subdirectory under the Keras cache dir where the file is
@@ -179,7 +183,8 @@ def get_file(
         hash_algorithm: Select the hash algorithm to verify the file.
             options are `"md5'`, `"sha256'`, and `"auto'`.
             The default 'auto' detects the hash algorithm in use.
-        extract: True tries extracting the file as an Archive, like tar or zip.
+        extract: If `True`, extracts the archive. Only applicable to compressed
+            archive files like tar or zip.
         archive_format: Archive format to try for extracting the file.
             Options are `"auto'`, `"tar'`, `"zip'`, and `None`.
             `"tar"` includes tar, tar.gz, and tar.bz files.
@@ -219,36 +224,50 @@ def get_file(
     datadir = os.path.join(datadir_base, cache_subdir)
     os.makedirs(datadir, exist_ok=True)
 
+    provided_fname = fname
     fname = path_to_string(fname)
+
     if not fname:
         fname = os.path.basename(urllib.parse.urlsplit(origin).path)
         if not fname:
             raise ValueError(
                 "Can't parse the file name from the origin provided: "
                 f"'{origin}'."
-                "Please specify the `fname` as the input param."
+                "Please specify the `fname` argument."
+            )
+    else:
+        if os.sep in fname:
+            raise ValueError(
+                "Paths are no longer accepted as the `fname` argument. "
+                "To specify the file's parent directory, use "
+                f"the `cache_dir` argument. Received: fname={fname}"
             )
 
-    if untar:
-        if fname.endswith(".tar.gz"):
-            fname = pathlib.Path(fname)
-            # The 2 `.with_suffix()` are because of `.tar.gz` as pathlib
-            # considers it as 2 suffixes.
-            fname = fname.with_suffix("").with_suffix("")
-            fname = str(fname)
-        untar_fpath = os.path.join(datadir, fname)
-        fpath = untar_fpath + ".tar.gz"
+    if extract or untar:
+        if provided_fname:
+            if "." in fname:
+                download_target = os.path.join(datadir, fname)
+                fname = fname[: fname.find(".")]
+                extraction_dir = os.path.join(datadir, fname + "_extracted")
+            else:
+                extraction_dir = os.path.join(datadir, fname)
+                download_target = os.path.join(datadir, fname + "_archive")
+        else:
+            extraction_dir = os.path.join(datadir, fname)
+            download_target = os.path.join(datadir, fname + "_archive")
     else:
-        fpath = os.path.join(datadir, fname)
+        download_target = os.path.join(datadir, fname)
 
     if force_download:
         download = True
-    elif os.path.exists(fpath):
+    elif os.path.exists(download_target):
         # File found in cache.
         download = False
         # Verify integrity if a hash was provided.
         if file_hash is not None:
-            if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
+            if not validate_file(
+                download_target, file_hash, algorithm=hash_algorithm
+            ):
                 io_utils.print_msg(
                     "A local file was found, but it seems to be "
                     f"incomplete or outdated because the {hash_algorithm} "
@@ -288,43 +307,42 @@ def __call__(self, block_num, block_size, total_size):
         error_msg = "URL fetch failure on {}: {} -- {}"
         try:
             try:
-                urlretrieve(origin, fpath, DLProgbar())
+                urlretrieve(origin, download_target, DLProgbar())
             except urllib.error.HTTPError as e:
                 raise Exception(error_msg.format(origin, e.code, e.msg))
             except urllib.error.URLError as e:
                 raise Exception(error_msg.format(origin, e.errno, e.reason))
         except (Exception, KeyboardInterrupt):
-            if os.path.exists(fpath):
-                os.remove(fpath)
+            if os.path.exists(download_target):
+                os.remove(download_target)
             raise
 
     # Validate download if succeeded and user provided an expected hash
     # Security conscious users would get the hash of the file from a
     # separate channel and pass it to this API to prevent MITM / corruption:
-    if os.path.exists(fpath) and file_hash is not None:
-        if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
+    if os.path.exists(download_target) and file_hash is not None:
+        if not validate_file(
+            download_target, file_hash, algorithm=hash_algorithm
+        ):
             raise ValueError(
                 "Incomplete or corrupted file detected. "
                 f"The {hash_algorithm} "
                 "file hash does not match the provided value "
                 f"of {file_hash}."
             )
 
-    if untar:
-        if not os.path.exists(untar_fpath):
-            status = extract_archive(fpath, datadir, archive_format="tar")
-            if not status:
-                warnings.warn("Could not extract archive.", stacklevel=2)
-        return untar_fpath
+    if extract or untar:
+        if untar:
+            archive_format = "tar"
 
-    if extract:
-        status = extract_archive(fpath, datadir, archive_format)
+        status = extract_archive(
+            download_target, extraction_dir, archive_format
+        )
         if not status:
             warnings.warn("Could not extract archive.", stacklevel=2)
+        return extraction_dir
 
-    # TODO: return extracted fpath if we extracted an archive,
-    # rather than the archive path.
-    return fpath
+    return download_target
 
 
 def resolve_hasher(algorithm, file_hash=None):
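
To make the new naming scheme above easier to follow, here is an illustrative-only restatement of the `fname` to download/extraction path mapping (the `~/.keras/datasets` prefix stands in for the resolved cache directory):

```python
import os

def resolve_targets(fname, datadir="~/.keras/datasets"):
    # Illustrative-only restatement of the mapping added in this commit,
    # for the case where the caller supplied `fname` and extract/untar is True.
    if "." in fname:
        # "mnist.tar.gz" -> download "mnist.tar.gz", extract into "mnist_extracted"
        download_target = os.path.join(datadir, fname)
        stem = fname[: fname.find(".")]
        extraction_dir = os.path.join(datadir, stem + "_extracted")
    else:
        # "mnist" -> extract into "mnist", store the archive as "mnist_archive"
        extraction_dir = os.path.join(datadir, fname)
        download_target = os.path.join(datadir, fname + "_archive")
    return download_target, extraction_dir

print(resolve_targets("mnist.tar.gz"))
# ('~/.keras/datasets/mnist.tar.gz', '~/.keras/datasets/mnist_extracted')
print(resolve_targets("mnist"))
# ('~/.keras/datasets/mnist_archive', '~/.keras/datasets/mnist')
```

In both cases `get_file` now returns the extraction directory; when `fname` is not supplied, the name parsed from `origin` goes down the `_archive` branch regardless of its extension.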

keras/src/utils/file_utils_test.py

Lines changed: 14 additions & 17 deletions
@@ -319,7 +319,7 @@ def test_valid_tar_extraction(self):
         """Test valid tar.gz extraction and hash validation."""
         dest_dir = self.get_temp_dir()
         orig_dir = self.get_temp_dir()
-        text_file_path, tar_file_path = self._create_tar_file(orig_dir)
+        _, tar_file_path = self._create_tar_file(orig_dir)
         self._test_file_extraction_and_validation(
             dest_dir, tar_file_path, "tar.gz"
         )
@@ -328,7 +328,7 @@ def test_valid_zip_extraction(self):
         """Test valid zip extraction and hash validation."""
         dest_dir = self.get_temp_dir()
         orig_dir = self.get_temp_dir()
-        text_file_path, zip_file_path = self._create_zip_file(orig_dir)
+        _, zip_file_path = self._create_zip_file(orig_dir)
         self._test_file_extraction_and_validation(
             dest_dir, zip_file_path, "zip"
         )
@@ -348,7 +348,7 @@ def test_get_file_with_tgz_extension(self):
         """Test extraction of file with .tar.gz extension."""
         dest_dir = self.get_temp_dir()
         orig_dir = dest_dir
-        text_file_path, tar_file_path = self._create_tar_file(orig_dir)
+        _, tar_file_path = self._create_tar_file(orig_dir)
 
         origin = urllib.parse.urljoin(
             "file://",
@@ -358,8 +358,8 @@ def test_get_file_with_tgz_extension(self):
         path = file_utils.get_file(
             "test.txt.tar.gz", origin, untar=True, cache_subdir=dest_dir
         )
-        self.assertTrue(path.endswith(".txt"))
         self.assertTrue(os.path.exists(path))
+        self.assertTrue(os.path.exists(os.path.join(path, "test.txt")))
 
     def test_get_file_with_integrity_check(self):
         """Test file download with integrity check."""
@@ -459,7 +459,7 @@ def _create_tar_file(self, directory):
             text_file.write("Float like a butterfly, sting like a bee.")
 
         with tarfile.open(tar_file_path, "w:gz") as tar_file:
-            tar_file.add(text_file_path)
+            tar_file.add(text_file_path, arcname="test.txt")
 
         return text_file_path, tar_file_path
 
@@ -471,7 +471,7 @@ def _create_zip_file(self, directory):
             text_file.write("Float like a butterfly, sting like a bee.")
 
         with zipfile.ZipFile(zip_file_path, "w") as zip_file:
-            zip_file.write(text_file_path)
+            zip_file.write(text_file_path, arcname="test.txt")
 
         return text_file_path, zip_file_path
 
@@ -484,7 +484,6 @@ def _test_file_extraction_and_validation(
             urllib.request.pathname2url(os.path.abspath(file_path)),
         )
 
-        hashval_sha256 = file_utils.hash_file(file_path)
         hashval_md5 = file_utils.hash_file(file_path, algorithm="md5")
 
         if archive_type:
@@ -499,17 +498,15 @@ def _test_file_extraction_and_validation(
             extract=extract,
             cache_subdir=dest_dir,
         )
-        path = file_utils.get_file(
-            "test",
-            origin,
-            file_hash=hashval_sha256,
-            extract=extract,
-            cache_subdir=dest_dir,
-        )
+        if extract:
+            fpath = path + "_archive"
+        else:
+            fpath = path
+
         self.assertTrue(os.path.exists(path))
-        self.assertTrue(file_utils.validate_file(path, hashval_sha256))
-        self.assertTrue(file_utils.validate_file(path, hashval_md5))
-        os.remove(path)
+        self.assertTrue(file_utils.validate_file(fpath, hashval_md5))
+        if extract:
+            self.assertTrue(os.path.exists(os.path.join(path, "test.txt")))
 
     def test_exists(self):
         temp_dir = self.get_temp_dir()
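
A hedged sketch of what these test changes imply for callers: with `extract=True`, `get_file` now returns the extraction directory, so a file-hash check has to target the archive saved beside it rather than the returned path. The origin URL and hash below are placeholders; the `_archive` suffix follows the naming this commit uses when `fname` has no extension.

```python
import os
from keras.utils import get_file
from keras.src.utils.file_utils import validate_file  # same helper the tests use

# Hypothetical origin, shown only to mirror the updated test flow.
path = get_file("data", origin="https://example.com/data.zip", extract=True)

print(os.listdir(path))  # contents of the extracted directory

# Hash checks apply to the downloaded archive, which (for an extension-less
# `fname`) is stored next to the extraction directory as "<fname>_archive".
archive_path = path + "_archive"
print(validate_file(archive_path, "<expected-md5-or-sha256>"))  # True only on a match
```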
