get_file() with tar, tgz, tar.bz, zip and sha256, resolves #5861. (#5882)

* get_file() with tar, tgz, tar.bz, zip and sha256, resolves #5861.

The changes were designed to preserve backwards compatibility while adding support
for .tar.gz, .tgz, .tar.bz, and .zip files.
sha256 hash is now supported in addition to md5.

* get_file() improve large file performance #5861.

* get_file() extract parameter fix (#5861)

* extract_archive() py3 fix (#5861)

* get_file() tarfile fix (#5861)

* data_utils.py and data_utils_test.py updated based on review (#5861)
# This is a combination of 4 commits.
# The first commit's message is:
get_file() with tar, tgz, tar.bz, zip and sha256, resolves #5861.

The changes were designed to preserve backwards compatibility while adding support
for .tar.gz, .tgz, .tar.bz, and .zip files.
Adds extract_archive() and hash_file() functions.
sha256 hash is now supported in addition to md5.
adds data_utils_test.py to test new functionality

# This is the 2nd commit message:

extract_archive() redundant open (#5861)

# This is the 3rd commit message:

data_utils.py and data_utils_test.py updated based on review (#5861)
test creates its own tiny file to download and extract locally.
test covers md5 sha256 zip and tar
_hash_file() now private
_extract_archive() now private

# This is the 4th commit message:

data_utils.py and data_utils_test.py updated based on review (#5861)
test creates its own tiny file to download and extract locally.
test covers md5 sha256 zip and tar
_hash_file() now private
_extract_archive() now private

* data_utils.py and data_utils_test.py updated based on review (#5861)

* data_utils.py get_file() cache_dir docs (#5861)

* data_utils.py address docs comments (#5861)

* get_file() comment link, path, & typo fix
This commit is contained in:
Andrew Hundt 2017-04-03 23:23:49 -04:00 committed by François Chollet
parent 64d2421599
commit 4fe78f3400
2 changed files with 209 additions and 35 deletions

@ -4,10 +4,12 @@ from __future__ import print_function
import functools
import tarfile
import zipfile
import os
import sys
import shutil
import hashlib
import six
from six.moves.urllib.request import urlopen
from six.moves.urllib.error import URLError
from six.moves.urllib.error import HTTPError
@ -55,24 +57,105 @@ else:
from six.moves.urllib.request import urlretrieve
def get_file(fname, origin, untar=False,
md5_hash=None, cache_subdir='datasets'):
"""Downloads a file from a URL if it not already in the cache.
Passing the MD5 hash will verify the file after download
as well as if it is already present in the cache.
def _extract_archive(file_path, path='.', archive_format='auto'):
"""Extracts an archive if it matches the tar, tar.gz, tar.bz, or zip formats
# Arguments
fname: name of the file
origin: original URL of the file
untar: boolean, whether the file should be decompressed
md5_hash: MD5 hash of the file for verification
cache_subdir: directory being used as the cache
file_path: path to the archive file
path: path to extract the archive file
archive_format: Archive format to try for extracting the file.
Options are 'auto', 'tar', 'zip', and None.
'tar' includes tar, tar.gz, and tar.bz files.
The default 'auto' is ['tar', 'zip'].
None or an empty list will return no matches found.
# Return:
True if a match was found and an archive extraction was completed,
False otherwise.
"""
if archive_format is None:
return False
if archive_format is 'auto':
archive_format = ['tar', 'zip']
if isinstance(archive_format, six.string_types):
archive_format = [archive_format]
for archive_type in archive_format:
if archive_type is 'tar':
open_fn = tarfile.open
is_match_fn = tarfile.is_tarfile
if archive_type is 'zip':
open_fn = zipfile.ZipFile
is_match_fn = zipfile.is_zipfile
if is_match_fn(file_path):
with open_fn(file_path) as archive:
try:
archive.extractall(path)
except (tarfile.TarError, RuntimeError,
KeyboardInterrupt) as e:
if os.path.exists(path):
if os.path.isfile(path):
os.remove(path)
else:
shutil.rmtree(path)
raise
return True
return False
def get_file(fname, origin, untar=False,
md5_hash=None, cache_subdir='datasets',
file_hash=None,
hash_algorithm='auto',
extract=False,
archive_format='auto',
cache_dir=None):
"""Downloads a file from a URL if it not already in the cache.
By default the file at the url `origin` is downloaded to the
cache_dir `~/.keras`, placed in the cache_subdir `datasets`,
and given the filename `fname`. The final location of a file
`example.txt` would therefore be `~/.keras/datasets/example.txt`.
Files in tar, tar.gz, tar.bz, and zip formats can also be extracted.
Passing a hash will verify the file after download. The command line
programs `shasum` and `sha256sum` can compute the hash.
# Arguments
fname: Name of the file. If an absolute path `/path/to/file.txt` is
specified the file will be saved at that location.
origin: Original URL of the file.
untar: Deprecated in favor of 'extract'.
boolean, whether the file should be decompressed
md5_hash: Deprecated in favor of 'file_hash'.
md5 hash of the file for verification
file_hash: The expected hash string of the file after download.
The sha256 and md5 hash algorithms are both supported.
cache_subdir: Subdirectory under the Keras cache dir where the file is
saved. If an absolute path `/path/to/folder` is
specified the file will be saved at that location.
hash_algorithm: Select the hash algorithm to verify the file.
options are 'md5', 'sha256', and 'auto'.
The default 'auto' detects the hash algorithm in use.
extract: True tries extracting the file as an Archive, like tar or zip.
archive_format: Archive format to try for extracting the file.
Options are 'auto', 'tar', 'zip', and None.
'tar' includes tar, tar.gz, and tar.bz files.
The default 'auto' is ['tar', 'zip'].
None or an empty list will return no matches found.
cache_dir: Location to store cached files, when None it
defaults to the [Keras Directory](/faq/#where-is-the-keras-configuration-filed-stored).
# Returns
Path to the downloaded file
"""
datadir_base = os.path.expanduser(os.path.join('~', '.keras'))
if cache_dir is None:
cache_dir = os.path.expanduser(os.path.join('~', '.keras'))
if md5_hash is not None and file_hash is None:
file_hash = md5_hash
hash_algorithm = 'md5'
datadir_base = os.path.expanduser(cache_dir)
if not os.access(datadir_base, os.W_OK):
datadir_base = os.path.join('/tmp', '.keras')
datadir = os.path.join(datadir_base, cache_subdir)
@ -88,10 +171,12 @@ def get_file(fname, origin, untar=False,
download = False
if os.path.exists(fpath):
# File found; verify integrity if a hash was provided.
if md5_hash is not None:
if not validate_file(fpath, md5_hash):
if file_hash is not None:
if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
print('A local file was found, but it seems to be '
'incomplete or outdated.')
'incomplete or outdated because the ' + hash_algorithm +
' file hash does not match the original value of ' +
file_hash + ' so we will re-download the data.')
download = True
else:
download = True
@ -123,38 +208,68 @@ def get_file(fname, origin, untar=False,
if untar:
if not os.path.exists(untar_fpath):
print('Untaring file...')
tfile = tarfile.open(fpath, 'r:gz')
try:
tfile.extractall(path=datadir)
except (Exception, KeyboardInterrupt) as e:
if os.path.exists(untar_fpath):
if os.path.isfile(untar_fpath):
os.remove(untar_fpath)
else:
shutil.rmtree(untar_fpath)
raise
tfile.close()
_extract_archive(fpath, datadir, archive_format='tar')
return untar_fpath
if extract:
_extract_archive(fpath, datadir, archive_format)
return fpath
def validate_file(fpath, md5_hash):
"""Validates a file against a MD5 hash.
def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
"""Calculates a file sha256 or md5 hash.
# Example
```python
>>> from keras.data_utils import _hash_file
>>> _hash_file('/path/to/file.zip')
'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
```
# Arguments
fpath: path to the file being validated
md5_hash: the MD5 hash being validated against
algorithm: hash algorithm, one of 'auto', 'sha256', or 'md5'.
The default 'auto' detects the hash algorithm in use.
chunk_size: Bytes to read at a time, important for large files.
# Returns
The file hash
"""
if (algorithm is 'sha256') or (algorithm is 'auto' and len(hash) is 64):
hasher = hashlib.sha256()
else:
hasher = hashlib.md5()
with open(fpath, 'rb') as fpath_file:
for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
hasher.update(chunk)
return hasher.hexdigest()
def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
    """Validates a file against a sha256 or md5 hash.

    # Arguments
        fpath: path to the file being validated
        file_hash: The expected hash string of the file.
            The sha256 and md5 hash algorithms are both supported.
        algorithm: Hash algorithm, one of 'auto', 'sha256', or 'md5'.
            The default 'auto' detects the algorithm from the length of
            `file_hash` (64 hex characters -> sha256, otherwise md5).
        chunk_size: Bytes to read at a time, important for large files.

    # Returns
        Whether the file is valid
    """
    # Use == rather than `is`: identity comparison against the string
    # literal 'sha256' and the int literal 64 only worked by accident of
    # CPython interning / small-int caching.
    if algorithm == 'sha256' or (algorithm == 'auto' and len(file_hash) == 64):
        hasher = 'sha256'
    else:
        hasher = 'md5'
    return str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash)

@ -0,0 +1,59 @@
"""Tests for functions in data_utils.py.
"""
import os
import pytest
import tarfile
import zipfile
from six.moves.urllib.request import pathname2url
from six.moves.urllib.parse import urljoin
from keras.utils.data_utils import get_file
from keras.utils.data_utils import validate_file
from keras.utils.data_utils import _hash_file
from keras import activations
from keras import regularizers
def test_data_utils():
    """Tests get_file from a url, plus extraction and validation.

    Builds a tiny local fixture file, packs it as both .tar.gz and .zip,
    serves each through a file:// URL, and checks download, hash
    verification (md5 and sha256) and extraction.
    """
    dirname = 'data_utils'
    # Create the fixture and both archive flavours locally so the test
    # needs no network access.
    with open('test.txt', 'w') as text_file:
        text_file.write('Float like a butterfly, sting like a bee.')
    with tarfile.open('test.tar.gz', 'w:gz') as tar_file:
        tar_file.add('test.txt')
    with zipfile.ZipFile('test.zip', 'w') as zip_file:
        zip_file.write('test.txt')
    # file:// URL pointing at the local tarball.
    origin = urljoin('file://', pathname2url(os.path.abspath('test.tar.gz')))
    # Legacy untar=True path; the downloaded archive lives at path + '.tar.gz'.
    path = get_file(dirname, origin, untar=True)
    filepath = path + '.tar.gz'
    hashval_sha256 = _hash_file(filepath)
    hashval_md5 = _hash_file(filepath, algorithm='md5')
    # Deprecated md5_hash argument must keep working.
    path = get_file(dirname, origin, md5_hash=hashval_md5, untar=True)
    # New-style file_hash + extract arguments.
    path = get_file(filepath, origin, file_hash=hashval_sha256, extract=True)
    assert os.path.exists(filepath)
    assert validate_file(filepath, hashval_sha256)
    assert validate_file(filepath, hashval_md5)
    os.remove(filepath)
    os.remove('test.tar.gz')
    # Repeat the cycle for the zip flavour.
    origin = urljoin('file://', pathname2url(os.path.abspath('test.zip')))
    hashval_sha256 = _hash_file('test.zip')
    hashval_md5 = _hash_file('test.zip', algorithm='md5')
    path = get_file(dirname, origin, md5_hash=hashval_md5, extract=True)
    path = get_file(dirname, origin, file_hash=hashval_sha256, extract=True)
    assert os.path.exists(path)
    assert validate_file(path, hashval_sha256)
    assert validate_file(path, hashval_md5)
    os.remove(path)
    os.remove('test.txt')
    os.remove('test.zip')


if __name__ == '__main__':
    pytest.main([__file__])