keras/datasets/imdb.py

import cPickle
import gzip
from data_utils import get_file
import random


def load_data(path="imdb.pkl", nb_words=100000, maxlen=None, test_split=0.2, seed=113):
    # Download the pickled dataset (cached locally by get_file) and open it,
    # transparently handling gzip-compressed files.
    path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/imdb.pkl")

    if path.endswith(".gz"):
        f = gzip.open(path, 'rb')
    else:
        f = open(path, 'rb')

    X, labels = cPickle.load(f)
    f.close()

    # Shuffle sequences and labels in unison by reseeding with the same seed.
    random.seed(seed)
    random.shuffle(X)
    random.seed(seed)
    random.shuffle(labels)

    # Optionally drop sequences longer than maxlen.
    if maxlen:
        new_X = []
        new_labels = []
        for x, y in zip(X, labels):
            if len(x) < maxlen:
                new_X.append(x)
                new_labels.append(y)
        X = new_X
        labels = new_labels

    # Clip the vocabulary: word indices >= nb_words map to the out-of-vocabulary index 1.
    X = [[1 if w >= nb_words else w for w in x] for x in X]

    # Split into train and test sets according to test_split.
    split = int(len(X) * (1 - test_split))
    X_train = X[:split]
    y_train = labels[:split]
    X_test = X[split:]
    y_test = labels[split:]

    return (X_train, y_train), (X_test, y_test)
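

# Minimal usage sketch (an illustration added here, not part of the original
# module): calling load_data downloads imdb.pkl from the S3 URL above on first
# use and returns word-index sequences split into train and test sets.
if __name__ == "__main__":
    (X_train, y_train), (X_test, y_test) = load_data(nb_words=20000, maxlen=100)
    print("%d training sequences" % len(X_train))
    print("%d test sequences" % len(X_test))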