* get_file() with tar, tgz, tar.bz, zip and sha256, resolves #5861. The changes were designed to preserve backwards compatibility while adding support for .tar.gz, .tgz, .tar.bz, and .zip files. sha256 hash is now supported in addition to md5. * get_file() improve large file performance #5861. * getfile() extract parameter fix (#5861) * extract_archive() py3 fix (#5861) * get_file() tarfile fix (#5861) * data_utils.py and data_utils_test.py updated based on review (#5861) # This is a combination of 4 commits. # The first commit's message is: get_file() with tar, tgz, tar.bz, zip and sha256, resolves #5861. The changes were designed to preserve backwards compatibility while adding support for .tar.gz, .tgz, .tar.bz, and .zip files. Adds extract_archive() and hash_file() functions. sha256 hash is now supported in addition to md5. adds data_utils_test.py to test new functionality # This is the 2nd commit message: extract_archive() redundant open (#5861) # This is the 3rd commit message: data_utils.py and data_utils_test.py updated based on review (#5861) test creates its own tiny file to download and extract locally. test covers md5 sha256 zip and tar _hash_file() now private _extract_archive() now private # This is the 4th commit message: data_utils.py and data_utils_test.py updated based on review (#5861) test creates its own tiny file to download and extract locally. test covers md5 sha256 zip and tar _hash_file() now private _extract_archive() now private * data_utils.py and data_utils_test.py updated based on review (#5861) * data_utils.py get_file() cache_dir docs (#5861) * data_utils.py address docs comments (#5861) * get_file() comment link, path, & typo fix
This commit is contained in:
parent
64d2421599
commit
4fe78f3400
@ -4,10 +4,12 @@ from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import tarfile
|
||||
import zipfile
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import hashlib
|
||||
import six
|
||||
from six.moves.urllib.request import urlopen
|
||||
from six.moves.urllib.error import URLError
|
||||
from six.moves.urllib.error import HTTPError
|
||||
@ -55,24 +57,105 @@ else:
|
||||
from six.moves.urllib.request import urlretrieve
|
||||
|
||||
|
||||
def get_file(fname, origin, untar=False,
|
||||
md5_hash=None, cache_subdir='datasets'):
|
||||
"""Downloads a file from a URL if it not already in the cache.
|
||||
|
||||
Passing the MD5 hash will verify the file after download
|
||||
as well as if it is already present in the cache.
|
||||
def _extract_archive(file_path, path='.', archive_format='auto'):
    """Extracts an archive if it matches the tar, tar.gz, tar.bz, or zip formats.

    # Arguments
        file_path: path to the archive file
        path: path to extract the archive file
        archive_format: Archive format to try for extracting the file.
            Options are 'auto', 'tar', 'zip', and None.
            'tar' includes tar, tar.gz, and tar.bz files.
            The default 'auto' is ['tar', 'zip'].
            None or an empty list will return no matches found.

    # Returns
        True if a match was found and an archive extraction was completed,
        False otherwise.
    """
    if archive_format is None:
        return False
    # Use `==` for string comparison: `is` only works by accident of
    # CPython string interning and warns on newer interpreters.
    if archive_format == 'auto':
        archive_format = ['tar', 'zip']
    if isinstance(archive_format, six.string_types):
        archive_format = [archive_format]

    for archive_type in archive_format:
        if archive_type == 'tar':
            open_fn = tarfile.open
            is_match_fn = tarfile.is_tarfile
        elif archive_type == 'zip':
            open_fn = zipfile.ZipFile
            is_match_fn = zipfile.is_zipfile
        else:
            # Unknown format keyword: skip it instead of reusing the
            # previous iteration's functions (or raising NameError).
            continue

        if is_match_fn(file_path):
            with open_fn(file_path) as archive:
                try:
                    archive.extractall(path)
                except (tarfile.TarError, RuntimeError,
                        KeyboardInterrupt):
                    # Remove any partially extracted result before
                    # re-raising, so a retry starts from a clean state.
                    if os.path.exists(path):
                        if os.path.isfile(path):
                            os.remove(path)
                        else:
                            shutil.rmtree(path)
                    raise
            return True
    return False
|
||||
|
||||
|
||||
def get_file(fname, origin, untar=False,
|
||||
md5_hash=None, cache_subdir='datasets',
|
||||
file_hash=None,
|
||||
hash_algorithm='auto',
|
||||
extract=False,
|
||||
archive_format='auto',
|
||||
cache_dir=None):
|
||||
"""Downloads a file from a URL if it not already in the cache.
|
||||
|
||||
By default the file at the url `origin` is downloaded to the
|
||||
cache_dir `~/.keras`, placed in the cache_subdir `datasets`,
|
||||
and given the filename `fname`. The final location of a file
|
||||
`example.txt` would therefore be `~/.keras/datasets/example.txt`.
|
||||
|
||||
Files in tar, tar.gz, tar.bz, and zip formats can also be extracted.
|
||||
Passing a hash will verify the file after download. The command line
|
||||
programs `shasum` and `sha256sum` can compute the hash.
|
||||
|
||||
# Arguments
|
||||
fname: Name of the file. If an absolute path `/path/to/file.txt` is
|
||||
specified the file will be saved at that location.
|
||||
origin: Original URL of the file.
|
||||
untar: Deprecated in favor of 'extract'.
|
||||
boolean, whether the file should be decompressed
|
||||
md5_hash: Deprecated in favor of 'file_hash'.
|
||||
md5 hash of the file for verification
|
||||
file_hash: The expected hash string of the file after download.
|
||||
The sha256 and md5 hash algorithms are both supported.
|
||||
cache_subdir: Subdirectory under the Keras cache dir where the file is
|
||||
saved. If an absolute path `/path/to/folder` is
|
||||
specified the file will be saved at that location.
|
||||
hash_algorithm: Select the hash algorithm to verify the file.
|
||||
options are 'md5', 'sha256', and 'auto'.
|
||||
The default 'auto' detects the hash algorithm in use.
|
||||
extract: True tries extracting the file as an Archive, like tar or zip.
|
||||
archive_format: Archive format to try for extracting the file.
|
||||
Options are 'auto', 'tar', 'zip', and None.
|
||||
'tar' includes tar, tar.gz, and tar.bz files.
|
||||
The default 'auto' is ['tar', 'zip'].
|
||||
None or an empty list will return no matches found.
|
||||
cache_dir: Location to store cached files, when None it
|
||||
defaults to the [Keras Directory](/faq/#where-is-the-keras-configuration-filed-stored).
|
||||
|
||||
# Returns
|
||||
Path to the downloaded file
|
||||
"""
|
||||
datadir_base = os.path.expanduser(os.path.join('~', '.keras'))
|
||||
if cache_dir is None:
|
||||
cache_dir = os.path.expanduser(os.path.join('~', '.keras'))
|
||||
if md5_hash is not None and file_hash is None:
|
||||
file_hash = md5_hash
|
||||
hash_algorithm = 'md5'
|
||||
datadir_base = os.path.expanduser(cache_dir)
|
||||
if not os.access(datadir_base, os.W_OK):
|
||||
datadir_base = os.path.join('/tmp', '.keras')
|
||||
datadir = os.path.join(datadir_base, cache_subdir)
|
||||
@ -88,10 +171,12 @@ def get_file(fname, origin, untar=False,
|
||||
download = False
|
||||
if os.path.exists(fpath):
|
||||
# File found; verify integrity if a hash was provided.
|
||||
if md5_hash is not None:
|
||||
if not validate_file(fpath, md5_hash):
|
||||
if file_hash is not None:
|
||||
if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
|
||||
print('A local file was found, but it seems to be '
|
||||
'incomplete or outdated.')
|
||||
'incomplete or outdated because the ' + hash_algorithm +
|
||||
' file hash does not match the original value of ' +
|
||||
file_hash + ' so we will re-download the data.')
|
||||
download = True
|
||||
else:
|
||||
download = True
|
||||
@ -123,38 +208,68 @@ def get_file(fname, origin, untar=False,
|
||||
|
||||
if untar:
|
||||
if not os.path.exists(untar_fpath):
|
||||
print('Untaring file...')
|
||||
tfile = tarfile.open(fpath, 'r:gz')
|
||||
try:
|
||||
tfile.extractall(path=datadir)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
if os.path.exists(untar_fpath):
|
||||
if os.path.isfile(untar_fpath):
|
||||
os.remove(untar_fpath)
|
||||
else:
|
||||
shutil.rmtree(untar_fpath)
|
||||
raise
|
||||
tfile.close()
|
||||
_extract_archive(fpath, datadir, archive_format='tar')
|
||||
return untar_fpath
|
||||
|
||||
if extract:
|
||||
_extract_archive(fpath, datadir, archive_format)
|
||||
|
||||
return fpath
|
||||
|
||||
|
||||
def validate_file(fpath, md5_hash):
|
||||
"""Validates a file against a MD5 hash.
|
||||
def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
|
||||
"""Calculates a file sha256 or md5 hash.
|
||||
|
||||
# Example
|
||||
|
||||
```python
|
||||
>>> from keras.data_utils import _hash_file
|
||||
>>> _hash_file('/path/to/file.zip')
|
||||
'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
|
||||
```
|
||||
|
||||
# Arguments
|
||||
fpath: path to the file being validated
|
||||
md5_hash: the MD5 hash being validated against
|
||||
algorithm: hash algorithm, one of 'auto', 'sha256', or 'md5'.
|
||||
The default 'auto' detects the hash algorithm in use.
|
||||
chunk_size: Bytes to read at a time, important for large files.
|
||||
|
||||
# Returns
|
||||
The file hash
|
||||
"""
|
||||
if (algorithm is 'sha256') or (algorithm is 'auto' and len(hash) is 64):
|
||||
hasher = hashlib.sha256()
|
||||
else:
|
||||
hasher = hashlib.md5()
|
||||
|
||||
with open(fpath, 'rb') as fpath_file:
|
||||
for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
|
||||
hasher.update(chunk)
|
||||
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
    """Validates a file against a sha256 or md5 hash.

    # Arguments
        fpath: path to the file being validated
        file_hash: The expected hash string of the file.
            The sha256 and md5 hash algorithms are both supported.
        algorithm: Hash algorithm, one of 'auto', 'sha256', or 'md5'.
            The default 'auto' detects the hash algorithm in use.
        chunk_size: Bytes to read at a time, important for large files.

    # Returns
        Whether the file is valid
    """
    # Fix: use equality (`==`) rather than identity (`is`) — string and
    # int identity comparisons only pass by accident of CPython interning
    # and the small-int cache, and warn on newer interpreters.
    # A 64-character hex digest is sha256; md5 digests are 32 characters.
    if ((algorithm == 'sha256') or
            (algorithm == 'auto' and len(file_hash) == 64)):
        hasher = 'sha256'
    else:
        hasher = 'md5'

    # Return the comparison directly instead of an if/True/else/False.
    return str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash)
|
||||
|
59
tests/keras/utils/data_utils_test.py
Normal file
59
tests/keras/utils/data_utils_test.py
Normal file
@ -0,0 +1,59 @@
|
||||
"""Tests for functions in data_utils.py.
|
||||
"""
|
||||
import os
|
||||
import pytest
|
||||
import tarfile
|
||||
import zipfile
|
||||
from six.moves.urllib.request import pathname2url
|
||||
from six.moves.urllib.parse import urljoin
|
||||
from keras.utils.data_utils import get_file
|
||||
from keras.utils.data_utils import validate_file
|
||||
from keras.utils.data_utils import _hash_file
|
||||
from keras import activations
|
||||
from keras import regularizers
|
||||
|
||||
|
||||
def test_data_utils():
    """Tests get_file from a url, plus extraction and validation.
    """
    dirname = 'data_utils'

    # Build a tiny local fixture file, then a tar.gz and a zip archive of
    # it, so the test "downloads" from a local file:// URL — no network.
    with open('test.txt', 'w') as text_file:
        text_file.write('Float like a butterfly, sting like a bee.')

    with tarfile.open('test.tar.gz', 'w:gz') as tar_file:
        tar_file.add('test.txt')

    with zipfile.ZipFile('test.zip', 'w') as zip_file:
        zip_file.write('test.txt')

    # Tar archive: exercise the deprecated untar/md5_hash arguments and
    # the newer file_hash/extract arguments against the same origin.
    origin = urljoin('file://', pathname2url(os.path.abspath('test.tar.gz')))

    path = get_file(dirname, origin, untar=True)
    filepath = path + '.tar.gz'
    hashval_sha256 = _hash_file(filepath)
    hashval_md5 = _hash_file(filepath, algorithm='md5')
    # Repeated get_file calls hit the cache-validation code paths
    # (file already present, hash checked before deciding to re-download).
    path = get_file(dirname, origin, md5_hash=hashval_md5, untar=True)
    path = get_file(filepath, origin, file_hash=hashval_sha256, extract=True)
    assert os.path.exists(filepath)
    assert validate_file(filepath, hashval_sha256)
    assert validate_file(filepath, hashval_md5)
    os.remove(filepath)
    os.remove('test.tar.gz')

    # Zip archive: same flow, covering both md5 and sha256 validation
    # with the extract=True path.
    origin = urljoin('file://', pathname2url(os.path.abspath('test.zip')))

    hashval_sha256 = _hash_file('test.zip')
    hashval_md5 = _hash_file('test.zip', algorithm='md5')
    path = get_file(dirname, origin, md5_hash=hashval_md5, extract=True)
    path = get_file(dirname, origin, file_hash=hashval_sha256, extract=True)
    assert os.path.exists(path)
    assert validate_file(path, hashval_sha256)
    assert validate_file(path, hashval_md5)

    # Clean up all fixtures created in the working directory.
    os.remove(path)
    os.remove('test.txt')
    os.remove('test.zip')
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
Loading…
Reference in New Issue
Block a user