Merge branch 'master' of github.com:jnphilipp/keras
This commit is contained in:
commit
37f4d11ea9
@ -1,15 +1,15 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
import sys
|
import sys
|
||||||
import six.moves.cPickle
|
from six.moves import cPickle
|
||||||
from six.moves import range
|
from six.moves import range
|
||||||
|
|
||||||
def load_batch(fpath, label_key='labels'):
|
def load_batch(fpath, label_key='labels'):
|
||||||
f = open(fpath, 'rb')
|
f = open(fpath, 'rb')
|
||||||
if sys.version_info < (3,):
|
if sys.version_info < (3,):
|
||||||
d = six.moves.cPickle.load(f)
|
d = cPickle.load(f)
|
||||||
else:
|
else:
|
||||||
d = six.moves.cPickle.load(f, encoding="bytes")
|
d = cPickle.load(f, encoding="bytes")
|
||||||
# decode utf8
|
# decode utf8
|
||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
del(d[k])
|
del(d[k])
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
import six.moves.cPickle
|
import cPickle
|
||||||
import gzip
|
import gzip
|
||||||
from .data_utils import get_file
|
from .data_utils import get_file
|
||||||
import random
|
import random
|
||||||
@ -17,7 +17,7 @@ def load_data(path="imdb.pkl", nb_words=None, skip_top=0, maxlen=None, test_spli
|
|||||||
else:
|
else:
|
||||||
f = open(path, 'rb')
|
f = open(path, 'rb')
|
||||||
|
|
||||||
X, labels = six.moves.cPickle.load(f)
|
X, labels = cPickle.load(f)
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import gzip
|
import gzip
|
||||||
from .data_utils import get_file
|
from .data_utils import get_file
|
||||||
import six.moves.cPickle
|
from six.moves import cPickle
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
@ -14,9 +14,9 @@ def load_data(path="mnist.pkl.gz"):
|
|||||||
f = open(path, 'rb')
|
f = open(path, 'rb')
|
||||||
|
|
||||||
if sys.version_info < (3,):
|
if sys.version_info < (3,):
|
||||||
data = six.moves.cPickle.load(f)
|
data = cPickle.load(f)
|
||||||
else:
|
else:
|
||||||
data = six.moves.cPickle.load(f, encoding="bytes")
|
data = cPickle.load(f, encoding="bytes")
|
||||||
|
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from .data_utils import get_file
|
|||||||
import string
|
import string
|
||||||
import random
|
import random
|
||||||
import os
|
import os
|
||||||
import six.moves.cPickle
|
from six.moves import cPickle
|
||||||
from six.moves import zip
|
from six.moves import zip
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -78,8 +78,8 @@ def make_reuters_dataset(path=os.path.join('datasets', 'temp', 'reuters21578'),
|
|||||||
dataset = (X, labels)
|
dataset = (X, labels)
|
||||||
print('-')
|
print('-')
|
||||||
print('Saving...')
|
print('Saving...')
|
||||||
six.moves.cPickle.dump(dataset, open(os.path.join('datasets', 'data', 'reuters.pkl'), 'w'))
|
cPickle.dump(dataset, open(os.path.join('datasets', 'data', 'reuters.pkl'), 'w'))
|
||||||
six.moves.cPickle.dump(tokenizer.word_index, open(os.path.join('datasets', 'data', 'reuters_word_index.pkl'), 'w'))
|
cPickle.dump(tokenizer.word_index, open(os.path.join('datasets', 'data', 'reuters_word_index.pkl'), 'w'))
|
||||||
|
|
||||||
|
|
||||||
def load_data(path="reuters.pkl", nb_words=None, skip_top=0, maxlen=None, test_split=0.2, seed=113,
|
def load_data(path="reuters.pkl", nb_words=None, skip_top=0, maxlen=None, test_split=0.2, seed=113,
|
||||||
@ -88,7 +88,7 @@ def load_data(path="reuters.pkl", nb_words=None, skip_top=0, maxlen=None, test_s
|
|||||||
path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/reuters.pkl")
|
path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/reuters.pkl")
|
||||||
f = open(path, 'rb')
|
f = open(path, 'rb')
|
||||||
|
|
||||||
X, labels = six.moves.cPickle.load(f)
|
X, labels = cPickle.load(f)
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
@ -140,7 +140,7 @@ def load_data(path="reuters.pkl", nb_words=None, skip_top=0, maxlen=None, test_s
|
|||||||
def get_word_index(path="reuters_word_index.pkl"):
|
def get_word_index(path="reuters_word_index.pkl"):
|
||||||
path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/reuters_word_index.pkl")
|
path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/reuters_word_index.pkl")
|
||||||
f = open(path, 'rb')
|
f = open(path, 'rb')
|
||||||
return six.moves.cPickle.load(f)
|
return cPickle.load(f)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -473,6 +473,7 @@ class JZS1(Recurrent):
|
|||||||
self.W_z, self.b_z,
|
self.W_z, self.b_z,
|
||||||
self.W_r, self.U_r, self.b_r,
|
self.W_r, self.U_r, self.b_r,
|
||||||
self.U_h, self.b_h,
|
self.U_h, self.b_h,
|
||||||
|
self.Pmat
|
||||||
]
|
]
|
||||||
|
|
||||||
if weights is not None:
|
if weights is not None:
|
||||||
@ -579,6 +580,7 @@ class JZS2(Recurrent):
|
|||||||
self.W_z, self.U_z, self.b_z,
|
self.W_z, self.U_z, self.b_z,
|
||||||
self.U_r, self.b_r,
|
self.U_r, self.b_r,
|
||||||
self.W_h, self.U_h, self.b_h,
|
self.W_h, self.U_h, self.b_h,
|
||||||
|
self.Pmat
|
||||||
]
|
]
|
||||||
|
|
||||||
if weights is not None:
|
if weights is not None:
|
||||||
|
Loading…
Reference in New Issue
Block a user