fix: upload cifar10 and cifar100 datasets (#20185)

* fix: naming of directories for cifar 10 and 100

* fix: naming of directories for cifar 10 and 100

---------

Co-authored-by: al mond <a@b.com>
This commit is contained in:
ghsanti 2024-08-29 18:51:33 +01:00 committed by GitHub
parent 8f6d71565b
commit 4c71314cfa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 4 deletions

@ -60,12 +60,12 @@ def load_data():
assert y_test.shape == (10000, 1) assert y_test.shape == (10000, 1)
``` ```
""" """
dirname = "cifar-10-batches-py" dirname = "cifar-10-batches-py-target"
origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
path = get_file( path = get_file(
fname=dirname, fname=dirname,
origin=origin, origin=origin,
untar=True, extract=True,
file_hash=( # noqa: E501 file_hash=( # noqa: E501
"6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce" "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
), ),
@ -76,6 +76,8 @@ def load_data():
x_train = np.empty((num_train_samples, 3, 32, 32), dtype="uint8") x_train = np.empty((num_train_samples, 3, 32, 32), dtype="uint8")
y_train = np.empty((num_train_samples,), dtype="uint8") y_train = np.empty((num_train_samples,), dtype="uint8")
# batches are within an inner folder
path = os.path.join(path, "cifar-10-batches-py")
for i in range(1, 6): for i in range(1, 6):
fpath = os.path.join(path, "data_batch_" + str(i)) fpath = os.path.join(path, "data_batch_" + str(i))
( (

@ -58,17 +58,18 @@ def load_data(label_mode="fine"):
f"Received: label_mode={label_mode}." f"Received: label_mode={label_mode}."
) )
dirname = "cifar-100-python" dirname = "cifar-100-python-target"
origin = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" origin = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
path = get_file( path = get_file(
fname=dirname, fname=dirname,
origin=origin, origin=origin,
untar=True, extract=True,
file_hash=( # noqa: E501 file_hash=( # noqa: E501
"85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7" "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"
), ),
) )
path = os.path.join(path, "cifar-100-python")
fpath = os.path.join(path, "train") fpath = os.path.join(path, "train")
x_train, y_train = load_batch(fpath, label_key=label_mode + "_labels") x_train, y_train = load_batch(fpath, label_key=label_mode + "_labels")