fix: upload cifar10 and cifar100 datasets (#20185)
* fix: naming of directories for cifar 10 and 100 * fix: naming of directories for cifar 10 and 100 --------- Co-authored-by: al mond <a@b.com>
This commit is contained in:
parent
8f6d71565b
commit
4c71314cfa
@ -60,12 +60,12 @@ def load_data():
|
|||||||
assert y_test.shape == (10000, 1)
|
assert y_test.shape == (10000, 1)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
dirname = "cifar-10-batches-py"
|
dirname = "cifar-10-batches-py-target"
|
||||||
origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
|
origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
|
||||||
path = get_file(
|
path = get_file(
|
||||||
fname=dirname,
|
fname=dirname,
|
||||||
origin=origin,
|
origin=origin,
|
||||||
untar=True,
|
extract=True,
|
||||||
file_hash=( # noqa: E501
|
file_hash=( # noqa: E501
|
||||||
"6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
|
"6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
|
||||||
),
|
),
|
||||||
@ -76,6 +76,8 @@ def load_data():
|
|||||||
x_train = np.empty((num_train_samples, 3, 32, 32), dtype="uint8")
|
x_train = np.empty((num_train_samples, 3, 32, 32), dtype="uint8")
|
||||||
y_train = np.empty((num_train_samples,), dtype="uint8")
|
y_train = np.empty((num_train_samples,), dtype="uint8")
|
||||||
|
|
||||||
|
# batches are within an inner folder
|
||||||
|
path = os.path.join(path, "cifar-10-batches-py")
|
||||||
for i in range(1, 6):
|
for i in range(1, 6):
|
||||||
fpath = os.path.join(path, "data_batch_" + str(i))
|
fpath = os.path.join(path, "data_batch_" + str(i))
|
||||||
(
|
(
|
||||||
|
@ -58,17 +58,18 @@ def load_data(label_mode="fine"):
|
|||||||
f"Received: label_mode={label_mode}."
|
f"Received: label_mode={label_mode}."
|
||||||
)
|
)
|
||||||
|
|
||||||
dirname = "cifar-100-python"
|
dirname = "cifar-100-python-target"
|
||||||
origin = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
|
origin = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
|
||||||
path = get_file(
|
path = get_file(
|
||||||
fname=dirname,
|
fname=dirname,
|
||||||
origin=origin,
|
origin=origin,
|
||||||
untar=True,
|
extract=True,
|
||||||
file_hash=( # noqa: E501
|
file_hash=( # noqa: E501
|
||||||
"85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"
|
"85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
path = os.path.join(path, "cifar-100-python")
|
||||||
fpath = os.path.join(path, "train")
|
fpath = os.path.join(path, "train")
|
||||||
x_train, y_train = load_batch(fpath, label_key=label_mode + "_labels")
|
x_train, y_train = load_batch(fpath, label_key=label_mode + "_labels")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user