41 lines
1.0 KiB
Python
41 lines
1.0 KiB
Python
import cPickle
|
|
import gzip
|
|
from data_utils import get_file
|
|
import random
|
|
|
|
def load_data(path="imdb.pkl", nb_words=100000, maxlen=None, test_split=0.2, seed=113):
    """Load the IMDB dataset pickle and split it into train and test sets.

    The file is obtained through ``get_file`` with the S3 origin below, then
    unpickled into ``(X, labels)``. It is shuffled, optionally filtered by
    sequence length, capped to ``nb_words``, and split.

    Args:
        path: file name handed to ``get_file``. A resolved path ending in
            ``.gz`` is read with ``gzip``; anything else with ``open``.
        nb_words: every value ``w >= nb_words`` inside a sequence is replaced
            by ``1``. Pass ``None`` to leave all values unchanged.
        maxlen: if truthy, only samples with ``len(x) < maxlen`` are kept
            (samples of exactly ``maxlen`` items are dropped). A falsy value
            (``None`` or ``0``) disables filtering.
        test_split: fraction of the (filtered) samples placed in the test set.
            The first ``int(len(X) * (1 - test_split))`` samples go to training.
        seed: seed for the shuffle. ``X`` and ``labels`` receive the same
            permutation, so sample/label pairs stay aligned.

    Returns:
        ``((X_train, y_train), (X_test, y_test))``. Each ``X_*`` is a list of
        lists (one per sample); each ``y_*`` holds the matching labels.

    Note:
        As a side effect this reseeds the module-level ``random`` generator
        with ``seed``.
    """
    path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/imdb.pkl")

    if path.endswith(".gz"):
        f = gzip.open(path, 'rb')
    else:
        f = open(path, 'rb')
    try:
        # SECURITY NOTE: unpickling can execute arbitrary code embedded in the
        # file. Only load pickles from a source you trust.
        X, labels = cPickle.load(f)
    finally:
        # Close the handle even if unpickling fails, so it is not leaked.
        f.close()

    # Reseeding before each shuffle yields the same permutation for both
    # sequences (they have equal length), keeping X[i] paired with labels[i].
    random.seed(seed)
    random.shuffle(X)
    random.seed(seed)
    random.shuffle(labels)

    if maxlen:
        kept = [(x, y) for x, y in zip(X, labels) if len(x) < maxlen]
        X = [x for x, _ in kept]
        labels = [y for _, y in kept]

    if nb_words is not None:
        # Values at or above the cap are collapsed onto 1 (presumably an
        # out-of-vocabulary marker -- confirm against the dataset encoding).
        X = [[1 if w >= nb_words else w for w in x] for x in X]
    else:
        # No cap: still normalize every sample to a list, matching the
        # capped branch's output type.
        X = [list(x) for x in X]

    # Compute the split point once so X and labels are cut identically.
    split_at = int(len(X) * (1 - test_split))
    X_train, y_train = X[:split_at], labels[:split_at]
    X_test, y_test = X[split_at:], labels[split_at:]

    return (X_train, y_train), (X_test, y_test)
|
|
|