@@ -163,14 +163,18 @@ def get_file(
163
163
```
164
164
165
165
Args:
166
- fname: Name of the file. If an absolute path, e.g. `"/path/to/file.txt"`
167
- is specified, the file will be saved at that location .
166
+ fname: If the target is a single file, this is your desired
167
+ local name for the file.
168
168
If `None`, the name of the file at `origin` will be used.
169
+ If downloading and extracting a directory archive,
170
+ the provided `fname` will be used as extraction directory
171
+ name (only if it doesn't have an extension).
169
172
origin: Original URL of the file.
170
173
untar: Deprecated in favor of `extract` argument.
171
- boolean, whether the file should be decompressed
174
+ Boolean, whether the file is a tar archive that should
175
+ be extracted.
172
176
md5_hash: Deprecated in favor of `file_hash` argument.
173
- md5 hash of the file for verification
177
+ md5 hash of the file for file integrity verification.
174
178
file_hash: The expected hash string of the file after download.
175
179
The sha256 and md5 hash algorithms are both supported.
176
180
cache_subdir: Subdirectory under the Keras cache dir where the file is
@@ -179,7 +183,8 @@ def get_file(
179
183
hash_algorithm: Select the hash algorithm to verify the file.
180
184
options are `"md5"`, `"sha256"`, and `"auto"`.
181
185
The default `"auto"` detects the hash algorithm in use.
182
- extract: True tries extracting the file as an Archive, like tar or zip.
186
+ extract: If `True`, extracts the archive. Only applicable to compressed
187
+ archive files like tar or zip.
183
188
archive_format: Archive format to try for extracting the file.
184
189
Options are `"auto"`, `"tar"`, `"zip"`, and `None`.
185
190
`"tar"` includes tar, tar.gz, and tar.bz files.
@@ -219,36 +224,50 @@ def get_file(
219
224
datadir = os .path .join (datadir_base , cache_subdir )
220
225
os .makedirs (datadir , exist_ok = True )
221
226
227
+ provided_fname = fname
222
228
fname = path_to_string (fname )
229
+
223
230
if not fname :
224
231
fname = os .path .basename (urllib .parse .urlsplit (origin ).path )
225
232
if not fname :
226
233
raise ValueError (
227
234
"Can't parse the file name from the origin provided: "
228
235
f"'{ origin } '."
229
- "Please specify the `fname` as the input param."
236
+ "Please specify the `fname` argument."
237
+ )
238
+ else :
239
+ if os .sep in fname :
240
+ raise ValueError (
241
+ "Paths are no longer accepted as the `fname` argument. "
242
+ "To specify the file's parent directory, use "
243
+ f"the `cache_dir` argument. Received: fname={ fname } "
230
244
)
231
245
232
- if untar :
233
- if fname .endswith (".tar.gz" ):
234
- fname = pathlib .Path (fname )
235
- # The 2 `.with_suffix()` are because of `.tar.gz` as pathlib
236
- # considers it as 2 suffixes.
237
- fname = fname .with_suffix ("" ).with_suffix ("" )
238
- fname = str (fname )
239
- untar_fpath = os .path .join (datadir , fname )
240
- fpath = untar_fpath + ".tar.gz"
246
+ if extract or untar :
247
+ if provided_fname :
248
+ if "." in fname :
249
+ download_target = os .path .join (datadir , fname )
250
+ fname = fname [: fname .find ("." )]
251
+ extraction_dir = os .path .join (datadir , fname + "_extracted" )
252
+ else :
253
+ extraction_dir = os .path .join (datadir , fname )
254
+ download_target = os .path .join (datadir , fname + "_archive" )
255
+ else :
256
+ extraction_dir = os .path .join (datadir , fname )
257
+ download_target = os .path .join (datadir , fname + "_archive" )
241
258
else :
242
- fpath = os .path .join (datadir , fname )
259
+ download_target = os .path .join (datadir , fname )
243
260
244
261
if force_download :
245
262
download = True
246
- elif os .path .exists (fpath ):
263
+ elif os .path .exists (download_target ):
247
264
# File found in cache.
248
265
download = False
249
266
# Verify integrity if a hash was provided.
250
267
if file_hash is not None :
251
- if not validate_file (fpath , file_hash , algorithm = hash_algorithm ):
268
+ if not validate_file (
269
+ download_target , file_hash , algorithm = hash_algorithm
270
+ ):
252
271
io_utils .print_msg (
253
272
"A local file was found, but it seems to be "
254
273
f"incomplete or outdated because the { hash_algorithm } "
@@ -288,43 +307,42 @@ def __call__(self, block_num, block_size, total_size):
288
307
error_msg = "URL fetch failure on {}: {} -- {}"
289
308
try :
290
309
try :
291
- urlretrieve (origin , fpath , DLProgbar ())
310
+ urlretrieve (origin , download_target , DLProgbar ())
292
311
except urllib .error .HTTPError as e :
293
312
raise Exception (error_msg .format (origin , e .code , e .msg ))
294
313
except urllib .error .URLError as e :
295
314
raise Exception (error_msg .format (origin , e .errno , e .reason ))
296
315
except (Exception , KeyboardInterrupt ):
297
- if os .path .exists (fpath ):
298
- os .remove (fpath )
316
+ if os .path .exists (download_target ):
317
+ os .remove (download_target )
299
318
raise
300
319
301
320
# Validate download if succeeded and user provided an expected hash
302
321
# Security conscious users would get the hash of the file from a
303
322
# separate channel and pass it to this API to prevent MITM / corruption:
304
- if os .path .exists (fpath ) and file_hash is not None :
305
- if not validate_file (fpath , file_hash , algorithm = hash_algorithm ):
323
+ if os .path .exists (download_target ) and file_hash is not None :
324
+ if not validate_file (
325
+ download_target , file_hash , algorithm = hash_algorithm
326
+ ):
306
327
raise ValueError (
307
328
"Incomplete or corrupted file detected. "
308
329
f"The { hash_algorithm } "
309
330
"file hash does not match the provided value "
310
331
f"of { file_hash } ."
311
332
)
312
333
313
- if untar :
314
- if not os .path .exists (untar_fpath ):
315
- status = extract_archive (fpath , datadir , archive_format = "tar" )
316
- if not status :
317
- warnings .warn ("Could not extract archive." , stacklevel = 2 )
318
- return untar_fpath
334
+ if extract or untar :
335
+ if untar :
336
+ archive_format = "tar"
319
337
320
- if extract :
321
- status = extract_archive (fpath , datadir , archive_format )
338
+ status = extract_archive (
339
+ download_target , extraction_dir , archive_format
340
+ )
322
341
if not status :
323
342
warnings .warn ("Could not extract archive." , stacklevel = 2 )
343
+ return extraction_dir
324
344
325
- # TODO: return extracted fpath if we extracted an archive,
326
- # rather than the archive path.
327
- return fpath
345
+ return download_target
328
346
329
347
330
348
def resolve_hasher (algorithm , file_hash = None ):
0 commit comments