Merge branch 'master' of github.com:jnphilipp/keras

jnphilipp 2015-09-04 13:49:52 +02:00
commit 37f4d11ea9
5 changed files with 15 additions and 13 deletions
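Note: the hunks below make two kinds of change. The dataset loaders switch from the fully qualified "import six.moves.cPickle" to "from six.moves import cPickle", which shortens every call site, and the JZS1/JZS2 recurrent layers register self.Pmat as a trainable parameter. A minimal sketch of why the six.moves import is the portable spelling (the file name is illustrative):

    # six.moves.cPickle resolves to the C-accelerated cPickle module on
    # Python 2 and to the standard pickle module on Python 3, so the
    # same import and the same call sites work under both interpreters.
    from six.moves import cPickle

    with open('data.pkl', 'rb') as f:  # illustrative path
        data = cPickle.load(f)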

@@ -1,15 +1,15 @@
 # -*- coding: utf-8 -*-
 from __future__ import absolute_import
 import sys
-import six.moves.cPickle
+from six.moves import cPickle
 from six.moves import range
 def load_batch(fpath, label_key='labels'):
     f = open(fpath, 'rb')
     if sys.version_info < (3,):
-        d = six.moves.cPickle.load(f)
+        d = cPickle.load(f)
     else:
-        d = six.moves.cPickle.load(f, encoding="bytes")
+        d = cPickle.load(f, encoding="bytes")
         # decode utf8
         for k, v in d.items():
             del(d[k])
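Note on the version check above: pickles written by Python 2 store dict keys as byte strings, so Python 3 has to be told how to read them (encoding="bytes"), after which the loader re-keys the dict as text. A self-contained sketch of the same pattern, not the repository's exact code:

    import sys
    from six.moves import cPickle

    def load_legacy_pickle(fpath):
        # Python 2 reads its own pickles directly; Python 3 needs
        # encoding="bytes" to accept byte-string data written by
        # Python 2, after which the keys are decoded back to text.
        with open(fpath, 'rb') as f:
            if sys.version_info < (3,):
                d = cPickle.load(f)
            else:
                d = cPickle.load(f, encoding='bytes')
                d = {k.decode('utf8'): v for k, v in d.items()}
        return d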

@@ -1,5 +1,5 @@
 from __future__ import absolute_import
-import six.moves.cPickle
+from six.moves import cPickle
 import gzip
 from .data_utils import get_file
 import random
@@ -17,7 +17,7 @@ def load_data(path="imdb.pkl", nb_words=None, skip_top=0, maxlen=None, test_spli
     else:
         f = open(path, 'rb')
-    X, labels = six.moves.cPickle.load(f)
+    X, labels = cPickle.load(f)
     f.close()
     np.random.seed(seed)

@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 import gzip
 from .data_utils import get_file
-import six.moves.cPickle
+from six.moves import cPickle
 import sys
@@ -14,9 +14,9 @@ def load_data(path="mnist.pkl.gz"):
         f = open(path, 'rb')
     if sys.version_info < (3,):
-        data = six.moves.cPickle.load(f)
+        data = cPickle.load(f)
     else:
-        data = six.moves.cPickle.load(f, encoding="bytes")
+        data = cPickle.load(f, encoding="bytes")
     f.close()

@@ -5,7 +5,7 @@ from .data_utils import get_file
 import string
 import random
 import os
-import six.moves.cPickle
+from six.moves import cPickle
 from six.moves import zip
 import numpy as np
@@ -78,8 +78,8 @@ def make_reuters_dataset(path=os.path.join('datasets', 'temp', 'reuters21578'),
     dataset = (X, labels)
     print('-')
     print('Saving...')
-    six.moves.cPickle.dump(dataset, open(os.path.join('datasets', 'data', 'reuters.pkl'), 'w'))
-    six.moves.cPickle.dump(tokenizer.word_index, open(os.path.join('datasets', 'data', 'reuters_word_index.pkl'), 'w'))
+    cPickle.dump(dataset, open(os.path.join('datasets', 'data', 'reuters.pkl'), 'w'))
+    cPickle.dump(tokenizer.word_index, open(os.path.join('datasets', 'data', 'reuters_word_index.pkl'), 'w'))
 def load_data(path="reuters.pkl", nb_words=None, skip_top=0, maxlen=None, test_split=0.2, seed=113,
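Note on the unchanged context above: make_reuters_dataset opens its output files in text mode ('w'), which Python 2 tolerates but Python 3 rejects for pickle output. A portable version of the dump step would use binary mode and a context manager, sketched here with the same paths and assuming the dataset tuple built above:

    import os
    from six.moves import cPickle

    # 'wb' is required on Python 3, where pickle emits bytes; the
    # with-block also closes the handle that the one-liner leaves open.
    with open(os.path.join('datasets', 'data', 'reuters.pkl'), 'wb') as f:
        cPickle.dump(dataset, f)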
@@ -88,7 +88,7 @@ def load_data(path="reuters.pkl", nb_words=None, skip_top=0, maxlen=None, test_s
     path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/reuters.pkl")
     f = open(path, 'rb')
-    X, labels = six.moves.cPickle.load(f)
+    X, labels = cPickle.load(f)
     f.close()
     np.random.seed(seed)
@@ -140,7 +140,7 @@ def load_data(path="reuters.pkl", nb_words=None, skip_top=0, maxlen=None, test_s
 def get_word_index(path="reuters_word_index.pkl"):
     path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/reuters_word_index.pkl")
     f = open(path, 'rb')
-    return six.moves.cPickle.load(f)
+    return cPickle.load(f)
 if __name__ == "__main__":
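For reference, a typical call into the loader shown above (hypothetical usage; the train/test return convention is inferred from the test_split parameter in the signature):

    from keras.datasets import reuters

    # test_split=0.2 holds out 20% of the samples for testing.
    (X_train, y_train), (X_test, y_test) = reuters.load_data(test_split=0.2)
    word_index = reuters.get_word_index()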

@@ -473,6 +473,7 @@ class JZS1(Recurrent):
             self.W_z, self.b_z,
             self.W_r, self.U_r, self.b_r,
             self.U_h, self.b_h,
+            self.Pmat
         ]
         if weights is not None:
@@ -579,6 +580,7 @@ class JZS2(Recurrent):
             self.W_z, self.U_z, self.b_z,
             self.U_r, self.b_r,
             self.W_h, self.U_h, self.b_h,
+            self.Pmat
         ]
         if weights is not None:
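The last two hunks append self.Pmat to the params list of JZS1 and JZS2. In Keras of this era, params is the list of shared variables the optimizer differentiates and updates, and it also fixes the ordering used by get_weights()/set_weights(); a variable left out of it stays frozen during training, which is presumably why these lines were added. A minimal Theano-style sketch of the convention (class name and shapes are illustrative):

    import numpy as np
    import theano

    class TinyLayer(object):
        def __init__(self, input_dim, output_dim):
            # Any shared variable missing from self.params is invisible
            # to the optimizer and never receives gradient updates.
            self.W = theano.shared(np.zeros((input_dim, output_dim)))
            self.Pmat = theano.shared(np.eye(output_dim))
            self.params = [self.W, self.Pmat]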