Merge branch 'master' of github.com:jnphilipp/keras

This commit is contained in:
jnphilipp 2015-09-04 13:49:52 +02:00
commit 37f4d11ea9
5 changed files with 15 additions and 13 deletions

@ -1,15 +1,15 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import absolute_import from __future__ import absolute_import
import sys import sys
import six.moves.cPickle from six.moves import cPickle
from six.moves import range from six.moves import range
def load_batch(fpath, label_key='labels'): def load_batch(fpath, label_key='labels'):
f = open(fpath, 'rb') f = open(fpath, 'rb')
if sys.version_info < (3,): if sys.version_info < (3,):
d = six.moves.cPickle.load(f) d = cPickle.load(f)
else: else:
d = six.moves.cPickle.load(f, encoding="bytes") d = cPickle.load(f, encoding="bytes")
# decode utf8 # decode utf8
for k, v in d.items(): for k, v in d.items():
del(d[k]) del(d[k])

@ -1,5 +1,5 @@
from __future__ import absolute_import from __future__ import absolute_import
import six.moves.cPickle import cPickle
import gzip import gzip
from .data_utils import get_file from .data_utils import get_file
import random import random
@ -17,7 +17,7 @@ def load_data(path="imdb.pkl", nb_words=None, skip_top=0, maxlen=None, test_spli
else: else:
f = open(path, 'rb') f = open(path, 'rb')
X, labels = six.moves.cPickle.load(f) X, labels = cPickle.load(f)
f.close() f.close()
np.random.seed(seed) np.random.seed(seed)

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import gzip import gzip
from .data_utils import get_file from .data_utils import get_file
import six.moves.cPickle from six.moves import cPickle
import sys import sys
@ -14,9 +14,9 @@ def load_data(path="mnist.pkl.gz"):
f = open(path, 'rb') f = open(path, 'rb')
if sys.version_info < (3,): if sys.version_info < (3,):
data = six.moves.cPickle.load(f) data = cPickle.load(f)
else: else:
data = six.moves.cPickle.load(f, encoding="bytes") data = cPickle.load(f, encoding="bytes")
f.close() f.close()

@ -5,7 +5,7 @@ from .data_utils import get_file
import string import string
import random import random
import os import os
import six.moves.cPickle from six.moves import cPickle
from six.moves import zip from six.moves import zip
import numpy as np import numpy as np
@ -78,8 +78,8 @@ def make_reuters_dataset(path=os.path.join('datasets', 'temp', 'reuters21578'),
dataset = (X, labels) dataset = (X, labels)
print('-') print('-')
print('Saving...') print('Saving...')
six.moves.cPickle.dump(dataset, open(os.path.join('datasets', 'data', 'reuters.pkl'), 'w')) cPickle.dump(dataset, open(os.path.join('datasets', 'data', 'reuters.pkl'), 'w'))
six.moves.cPickle.dump(tokenizer.word_index, open(os.path.join('datasets', 'data', 'reuters_word_index.pkl'), 'w')) cPickle.dump(tokenizer.word_index, open(os.path.join('datasets', 'data', 'reuters_word_index.pkl'), 'w'))
def load_data(path="reuters.pkl", nb_words=None, skip_top=0, maxlen=None, test_split=0.2, seed=113, def load_data(path="reuters.pkl", nb_words=None, skip_top=0, maxlen=None, test_split=0.2, seed=113,
@ -88,7 +88,7 @@ def load_data(path="reuters.pkl", nb_words=None, skip_top=0, maxlen=None, test_s
path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/reuters.pkl") path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/reuters.pkl")
f = open(path, 'rb') f = open(path, 'rb')
X, labels = six.moves.cPickle.load(f) X, labels = cPickle.load(f)
f.close() f.close()
np.random.seed(seed) np.random.seed(seed)
@ -140,7 +140,7 @@ def load_data(path="reuters.pkl", nb_words=None, skip_top=0, maxlen=None, test_s
def get_word_index(path="reuters_word_index.pkl"): def get_word_index(path="reuters_word_index.pkl"):
path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/reuters_word_index.pkl") path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/reuters_word_index.pkl")
f = open(path, 'rb') f = open(path, 'rb')
return six.moves.cPickle.load(f) return cPickle.load(f)
if __name__ == "__main__": if __name__ == "__main__":

@ -473,6 +473,7 @@ class JZS1(Recurrent):
self.W_z, self.b_z, self.W_z, self.b_z,
self.W_r, self.U_r, self.b_r, self.W_r, self.U_r, self.b_r,
self.U_h, self.b_h, self.U_h, self.b_h,
self.Pmat
] ]
if weights is not None: if weights is not None:
@ -579,6 +580,7 @@ class JZS2(Recurrent):
self.W_z, self.U_z, self.b_z, self.W_z, self.U_z, self.b_z,
self.U_r, self.b_r, self.U_r, self.b_r,
self.W_h, self.U_h, self.b_h, self.W_h, self.U_h, self.b_h,
self.Pmat
] ]
if weights is not None: if weights is not None: