29 lines
704 B
Python
29 lines
704 B
Python
"""Utilities common to CIFAR10 and CIFAR100 datasets."""
|
|
|
|
import _pickle as cPickle
|
|
|
|
|
|
def load_batch(fpath, label_key="labels"):
    """Internal utility for parsing CIFAR data.

    Args:
        fpath: path of the file to parse.
        label_key: key for label data in the retrieved
            dictionary.

    Returns:
        A tuple `(data, labels)`, where `data` is the array stored under
        the `"data"` key reshaped to `(data.shape[0], 3, 32, 32)` and
        `labels` is the object stored under `label_key`.
    """
    # NOTE: unpickling can execute arbitrary code; only load batch files
    # obtained from a trusted source.
    with open(fpath, "rb") as f:
        # `encoding="bytes"` keeps Python 2 8-bit strings (such as the keys
        # in the original CIFAR pickles) as `bytes` instead of decoding them.
        d = cPickle.load(f, encoding="bytes")

    # Normalize keys to `str`. Keys that are already `str` (e.g. a batch
    # re-pickled under Python 3) are kept as-is rather than crashing.
    d = {
        k.decode("utf8") if isinstance(k, bytes) else k: v
        for k, v in d.items()
    }
    data = d["data"]
    labels = d[label_key]

    data = data.reshape(data.shape[0], 3, 32, 32)
    return data, labels
|