keras/keras_core/utils/file_utils_test.py
2023-04-30 15:11:54 -07:00

170 lines
5.7 KiB
Python

import os
import tarfile
import urllib
import zipfile
from keras_core.testing import test_case
from keras_core.utils import file_utils
class TestGetFile(test_case.TestCase):
    """Tests for `file_utils.get_file` driven by local `file://` origins.

    Every test builds its fixtures (a text file and, where needed, tar.gz /
    zip archives of it) inside temp directories obtained from
    `self.get_temp_dir()`, so no network access is involved.
    """

    # Contents written to every fixture text file.
    _SAMPLE_TEXT = "Float like a butterfly, sting like a bee."

    @staticmethod
    def _file_origin(path):
        """Returns a `file://` URL for the absolute form of `path`."""
        # NOTE(review): the module only does `import urllib`; the
        # `urllib.parse` / `urllib.request` submodules resolving here
        # presumably relies on another import having loaded them -- confirm.
        return urllib.parse.urljoin(
            "file://",
            urllib.request.pathname2url(os.path.abspath(path)),
        )

    def _write_sample_text(self, path):
        """Writes the shared fixture sentence to a new text file at `path`."""
        with open(path, "w") as text_file:
            text_file.write(self._SAMPLE_TEXT)

    def test_get_file_and_validate_it(self):
        """Tests get_file from a url, plus extraction and validation.

        Covers, in order: a tar.gz origin (`untar=True` / `extract=True`),
        a zip origin (`extract=True`), calls without an explicit file name,
        and the error raised when `get_file` is called with no arguments.
        """
        dest_dir = self.get_temp_dir()
        orig_dir = self.get_temp_dir()

        text_file_path = os.path.join(orig_dir, "test.txt")
        zip_file_path = os.path.join(orig_dir, "test.zip")
        tar_file_path = os.path.join(orig_dir, "test.tar.gz")

        self._write_sample_text(text_file_path)
        with tarfile.open(tar_file_path, "w:gz") as tar_file:
            tar_file.add(text_file_path)
        with zipfile.ZipFile(zip_file_path, "w") as zip_file:
            zip_file.write(text_file_path)

        # --- tar.gz origin -------------------------------------------------
        origin = self._file_origin(tar_file_path)
        path = file_utils.get_file(
            "test.txt", origin, untar=True, cache_subdir=dest_dir
        )
        # The hashes are taken from the file at `path + ".tar.gz"`, i.e. the
        # archive this test expects to sit next to the returned path.
        filepath = path + ".tar.gz"
        hashval_sha256 = file_utils.hash_file(filepath)
        hashval_md5 = file_utils.hash_file(filepath, algorithm="md5")

        # Repeat the call with each hash supplied; the returned paths are not
        # needed, only that the calls succeed.
        file_utils.get_file(
            "test.txt",
            origin,
            md5_hash=hashval_md5,
            untar=True,
            cache_subdir=dest_dir,
        )
        file_utils.get_file(
            filepath,
            origin,
            file_hash=hashval_sha256,
            extract=True,
            cache_subdir=dest_dir,
        )
        self.assertTrue(os.path.exists(filepath))
        self.assertTrue(file_utils.validate_file(filepath, hashval_sha256))
        self.assertTrue(file_utils.validate_file(filepath, hashval_md5))
        os.remove(filepath)

        # --- zip origin ----------------------------------------------------
        origin = self._file_origin(zip_file_path)
        hashval_sha256 = file_utils.hash_file(zip_file_path)
        hashval_md5 = file_utils.hash_file(zip_file_path, algorithm="md5")
        file_utils.get_file(
            "test",
            origin,
            md5_hash=hashval_md5,
            extract=True,
            cache_subdir=dest_dir,
        )
        path = file_utils.get_file(
            "test",
            origin,
            file_hash=hashval_sha256,
            extract=True,
            cache_subdir=dest_dir,
        )
        self.assertTrue(os.path.exists(path))
        self.assertTrue(file_utils.validate_file(path, hashval_sha256))
        self.assertTrue(file_utils.validate_file(path, hashval_md5))
        os.remove(path)

        # --- no explicit file name ----------------------------------------
        cases = [
            (text_file_path, False),
            (tar_file_path, True),
            (zip_file_path, True),
        ]
        for file_path, extract in cases:
            origin = self._file_origin(file_path)
            hashval_sha256 = file_utils.hash_file(file_path)
            path = file_utils.get_file(
                origin=origin,
                file_hash=hashval_sha256,
                extract=extract,
                cache_subdir=dest_dir,
            )
            self.assertTrue(os.path.exists(path))
            self.assertTrue(file_utils.validate_file(path, hashval_sha256))
            os.remove(path)

        # --- missing origin -----------------------------------------------
        # `assertRaisesRegexp` is a deprecated alias that was removed in
        # Python 3.12; use `assertRaisesRegex`, as the other tests here do.
        with self.assertRaisesRegex(
            ValueError, 'Please specify the "origin".*'
        ):
            file_utils.get_file()

    def test_get_file_with_tgz_extension(self):
        """Tests get_file with `untar=True` and a `.tar.gz` file name.

        The origin archive lives in the same temp directory that is passed as
        `cache_subdir`. The returned path must end in ".txt" and exist.
        """
        dest_dir = self.get_temp_dir()
        orig_dir = dest_dir

        text_file_path = os.path.join(orig_dir, "test.txt")
        tar_file_path = os.path.join(orig_dir, "test.tar.gz")

        self._write_sample_text(text_file_path)
        with tarfile.open(tar_file_path, "w:gz") as tar_file:
            tar_file.add(text_file_path)

        origin = self._file_origin(tar_file_path)
        path = file_utils.get_file(
            "test.txt.tar.gz", origin, untar=True, cache_subdir=dest_dir
        )
        self.assertTrue(path.endswith(".txt"))
        self.assertTrue(os.path.exists(path))

    def test_get_file_with_integrity_check(self):
        """Tests get_file when given the hash of the origin file.

        The hash passed as `file_hash` is computed from the origin file
        itself, so the call is expected to return an existing path.
        """
        orig_dir = self.get_temp_dir()
        file_path = os.path.join(orig_dir, "test.txt")
        self._write_sample_text(file_path)

        hashval = file_utils.hash_file(file_path)
        origin = self._file_origin(file_path)

        path = file_utils.get_file("test.txt", origin, file_hash=hashval)
        self.assertTrue(os.path.exists(path))

    def test_get_file_with_failed_integrity_check(self):
        """Tests that get_file raises when `file_hash` does not match.

        A dummy hash of 64 zeros is supplied; the call must raise a
        `ValueError` reporting an incomplete or corrupted file.
        """
        orig_dir = self.get_temp_dir()
        file_path = os.path.join(orig_dir, "test.txt")
        self._write_sample_text(file_path)

        hashval = "0" * 64
        origin = self._file_origin(file_path)

        with self.assertRaisesRegex(
            ValueError, "Incomplete or corrupted file.*"
        ):
            file_utils.get_file("test.txt", origin, file_hash=hashval)