170 lines
5.7 KiB
Python
170 lines
5.7 KiB
Python
import os
|
|
import tarfile
|
|
import urllib
|
|
import zipfile
|
|
|
|
from keras_core.testing import test_case
|
|
from keras_core.utils import file_utils
|
|
|
|
|
|
class TestGetFile(test_case.TestCase):
|
|
def test_get_file_and_validate_it(self):
|
|
"""Tests get_file from a url, plus extraction and validation."""
|
|
dest_dir = self.get_temp_dir()
|
|
orig_dir = self.get_temp_dir()
|
|
|
|
text_file_path = os.path.join(orig_dir, "test.txt")
|
|
zip_file_path = os.path.join(orig_dir, "test.zip")
|
|
tar_file_path = os.path.join(orig_dir, "test.tar.gz")
|
|
|
|
with open(text_file_path, "w") as text_file:
|
|
text_file.write("Float like a butterfly, sting like a bee.")
|
|
|
|
with tarfile.open(tar_file_path, "w:gz") as tar_file:
|
|
tar_file.add(text_file_path)
|
|
|
|
with zipfile.ZipFile(zip_file_path, "w") as zip_file:
|
|
zip_file.write(text_file_path)
|
|
|
|
origin = urllib.parse.urljoin(
|
|
"file://",
|
|
urllib.request.pathname2url(os.path.abspath(tar_file_path)),
|
|
)
|
|
|
|
path = file_utils.get_file(
|
|
"test.txt", origin, untar=True, cache_subdir=dest_dir
|
|
)
|
|
filepath = path + ".tar.gz"
|
|
hashval_sha256 = file_utils.hash_file(filepath)
|
|
hashval_md5 = file_utils.hash_file(filepath, algorithm="md5")
|
|
path = file_utils.get_file(
|
|
"test.txt",
|
|
origin,
|
|
md5_hash=hashval_md5,
|
|
untar=True,
|
|
cache_subdir=dest_dir,
|
|
)
|
|
path = file_utils.get_file(
|
|
filepath,
|
|
origin,
|
|
file_hash=hashval_sha256,
|
|
extract=True,
|
|
cache_subdir=dest_dir,
|
|
)
|
|
self.assertTrue(os.path.exists(filepath))
|
|
self.assertTrue(file_utils.validate_file(filepath, hashval_sha256))
|
|
self.assertTrue(file_utils.validate_file(filepath, hashval_md5))
|
|
os.remove(filepath)
|
|
|
|
origin = urllib.parse.urljoin(
|
|
"file://",
|
|
urllib.request.pathname2url(os.path.abspath(zip_file_path)),
|
|
)
|
|
|
|
hashval_sha256 = file_utils.hash_file(zip_file_path)
|
|
hashval_md5 = file_utils.hash_file(zip_file_path, algorithm="md5")
|
|
path = file_utils.get_file(
|
|
"test",
|
|
origin,
|
|
md5_hash=hashval_md5,
|
|
extract=True,
|
|
cache_subdir=dest_dir,
|
|
)
|
|
path = file_utils.get_file(
|
|
"test",
|
|
origin,
|
|
file_hash=hashval_sha256,
|
|
extract=True,
|
|
cache_subdir=dest_dir,
|
|
)
|
|
self.assertTrue(os.path.exists(path))
|
|
self.assertTrue(file_utils.validate_file(path, hashval_sha256))
|
|
self.assertTrue(file_utils.validate_file(path, hashval_md5))
|
|
os.remove(path)
|
|
|
|
for file_path, extract in [
|
|
(text_file_path, False),
|
|
(tar_file_path, True),
|
|
(zip_file_path, True),
|
|
]:
|
|
origin = urllib.parse.urljoin(
|
|
"file://",
|
|
urllib.request.pathname2url(os.path.abspath(file_path)),
|
|
)
|
|
hashval_sha256 = file_utils.hash_file(file_path)
|
|
path = file_utils.get_file(
|
|
origin=origin,
|
|
file_hash=hashval_sha256,
|
|
extract=extract,
|
|
cache_subdir=dest_dir,
|
|
)
|
|
self.assertTrue(os.path.exists(path))
|
|
self.assertTrue(file_utils.validate_file(path, hashval_sha256))
|
|
os.remove(path)
|
|
|
|
with self.assertRaisesRegexp(
|
|
ValueError, 'Please specify the "origin".*'
|
|
):
|
|
_ = file_utils.get_file()
|
|
|
|
def test_get_file_with_tgz_extension(self):
|
|
"""Tests get_file from a url, plus extraction and validation."""
|
|
dest_dir = self.get_temp_dir()
|
|
orig_dir = dest_dir
|
|
|
|
text_file_path = os.path.join(orig_dir, "test.txt")
|
|
tar_file_path = os.path.join(orig_dir, "test.tar.gz")
|
|
|
|
with open(text_file_path, "w") as text_file:
|
|
text_file.write("Float like a butterfly, sting like a bee.")
|
|
|
|
with tarfile.open(tar_file_path, "w:gz") as tar_file:
|
|
tar_file.add(text_file_path)
|
|
|
|
origin = urllib.parse.urljoin(
|
|
"file://",
|
|
urllib.request.pathname2url(os.path.abspath(tar_file_path)),
|
|
)
|
|
|
|
path = file_utils.get_file(
|
|
"test.txt.tar.gz", origin, untar=True, cache_subdir=dest_dir
|
|
)
|
|
self.assertTrue(path.endswith(".txt"))
|
|
self.assertTrue(os.path.exists(path))
|
|
|
|
def test_get_file_with_integrity_check(self):
|
|
"""Tests get_file with validation before download."""
|
|
orig_dir = self.get_temp_dir()
|
|
file_path = os.path.join(orig_dir, "test.txt")
|
|
|
|
with open(file_path, "w") as text_file:
|
|
text_file.write("Float like a butterfly, sting like a bee.")
|
|
|
|
hashval = file_utils.hash_file(file_path)
|
|
|
|
origin = urllib.parse.urljoin(
|
|
"file://", urllib.request.pathname2url(os.path.abspath(file_path))
|
|
)
|
|
|
|
path = file_utils.get_file("test.txt", origin, file_hash=hashval)
|
|
self.assertTrue(os.path.exists(path))
|
|
|
|
def test_get_file_with_failed_integrity_check(self):
|
|
"""Tests get_file with validation before download."""
|
|
orig_dir = self.get_temp_dir()
|
|
file_path = os.path.join(orig_dir, "test.txt")
|
|
|
|
with open(file_path, "w") as text_file:
|
|
text_file.write("Float like a butterfly, sting like a bee.")
|
|
|
|
hashval = "0" * 64
|
|
|
|
origin = urllib.parse.urljoin(
|
|
"file://", urllib.request.pathname2url(os.path.abspath(file_path))
|
|
)
|
|
|
|
with self.assertRaisesRegex(
|
|
ValueError, "Incomplete or corrupted file.*"
|
|
):
|
|
_ = file_utils.get_file("test.txt", origin, file_hash=hashval)
|