Update all examples with new API

Makoto Matsuyama 2015-10-04 18:44:49 -07:00
parent 35d66d672b
commit 2bd4c295d6
17 changed files with 140 additions and 139 deletions
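All of the diffs below follow the same migration pattern: under the old API every layer was given both its input and output dimensions explicitly (for example Dense(784, 128)), while under the new API a layer only takes its output dimension and infers its input size from the previous layer; only the first layer of a Sequential model declares its input via input_shape (or input_dim). A minimal sketch of the pattern, using the dimensions from the MNIST MLP example further down (784 inputs, 128 hidden units, 10 classes):

from keras.models import Sequential
from keras.layers.core import Dense, Activation

# Old API: each layer named both its input and output dimensions, e.g.
#   model.add(Dense(784, 128))
#   model.add(Dense(128, 10))
# New API: only the first layer declares the input shape; every later
# layer infers its input size from the layer before it.
model = Sequential()
model.add(Dense(128, input_shape=(784,)))
model.add(Activation('relu'))
model.add(Dense(10))
model.add(Activation('softmax'))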

@ -1,9 +1,8 @@
# -*- coding: utf-8 -*-
from __future__ import print_function
from keras.models import Sequential, slice_X
from keras.layers.core import Activation, Dense, RepeatVector
from keras.layers.core import Activation, TimeDistributedDense, RepeatVector
from keras.layers import recurrent
from sklearn.utils import shuffle
import numpy as np
"""
@ -25,18 +24,15 @@ and
http://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf
Theoretically it introduces shorter term dependencies between source and target.
Two digits inverted:
+ One layer JZS1 (128 HN), 5k training examples = 99% train/test accuracy in 55 epochs
Three digits inverted:
+ One layer JZS1 (128 HN), 50k training examples = 99% train/test accuracy in 100 epochs
Four digits inverted:
+ One layer JZS1 (128 HN), 400k training examples = 99% train/test accuracy in 20 epochs
Five digits inverted:
+ One layer JZS1 (128 HN), 550k training examples = 99% train/test accuracy in 30 epochs
@ -122,23 +118,32 @@ for i, sentence in enumerate(expected):
y[i] = ctable.encode(sentence, maxlen=DIGITS + 1)
# Shuffle (X, y) in unison as the later parts of X will almost all be larger digits
X, y = shuffle(X, y)
indices = np.arange(len(y))
np.random.shuffle(indices)
X = X[indices]
y = y[indices]
# Explicitly set apart 10% for validation data that we never train over
split_at = len(X) - len(X) / 10
(X_train, X_val) = (slice_X(X, 0, split_at), slice_X(X, split_at))
(y_train, y_val) = (y[:split_at], y[split_at:])
print(X_train.shape)
print(y_train.shape)
print('Build model...')
model = Sequential()
# "Encode" the input sequence using an RNN, producing an output of HIDDEN_SIZE
model.add(RNN(len(chars), HIDDEN_SIZE))
# note: in a situation where your input sequences have a variable length,
# use input_shape=(None, nb_feature).
model.add(RNN(HIDDEN_SIZE, input_shape=(None, len(chars))))
# For the decoder's input, we repeat the encoded input for each time step
model.add(RepeatVector(DIGITS + 1))
# The decoder RNN could be multiple layers stacked or a single layer
for _ in xrange(LAYERS):
model.add(RNN(HIDDEN_SIZE, HIDDEN_SIZE, return_sequences=True))
model.add(RNN(HIDDEN_SIZE, return_sequences=True))
# For each step of the output sequence, decide which character should be chosen
model.add(Dense(HIDDEN_SIZE, len(chars)))
model.add(TimeDistributedDense(len(chars)))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')
@ -148,7 +153,7 @@ for iteration in range(1, 200):
print()
print('-' * 50)
print('Iteration', iteration)
model.fit(X, y, batch_size=BATCH_SIZE, nb_epoch=1, validation_data=(X_val, y_val), show_accuracy=True)
model.fit(X_train, y_train, batch_size=BATCH_SIZE, nb_epoch=1, validation_data=(X_val, y_val), show_accuracy=True)
###
# Select 10 samples from the validation set at random so we can visualize errors
for i in xrange(10):

@ -181,15 +181,15 @@ print('Build model...')
sentrnn = Sequential()
sentrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, mask_zero=True))
sentrnn.add(RNN(EMBED_HIDDEN_SIZE, SENT_HIDDEN_SIZE, return_sequences=False))
sentrnn.add(RNN(SENT_HIDDEN_SIZE, return_sequences=False))
qrnn = Sequential()
qrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE))
qrnn.add(RNN(EMBED_HIDDEN_SIZE, QUERY_HIDDEN_SIZE, return_sequences=False))
qrnn.add(RNN(QUERY_HIDDEN_SIZE, return_sequences=False))
model = Sequential()
model.add(Merge([sentrnn, qrnn], mode='concat'))
model.add(Dense(SENT_HIDDEN_SIZE + QUERY_HIDDEN_SIZE, vocab_size, activation='softmax'))
model.add(Dense(vocab_size, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', class_mode='categorical')

@ -28,16 +28,10 @@ nb_classes = 10
nb_epoch = 200
data_augmentation = True
# shape of the image (SHAPE x SHAPE)
shapex, shapey = 32, 32
# number of convolutional filters to use at each layer
nb_filters = [32, 64]
# level of pooling to perform at each layer (POOL x POOL)
nb_pool = [2, 2]
# level of convolution to perform at each layer (CONV x CONV)
nb_conv = [3, 3]
# input image dimensions
img_rows, img_cols = 32, 32
# the CIFAR10 images are RGB
image_dimensions = 3
img_channels = 3
# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
@ -51,28 +45,26 @@ Y_test = np_utils.to_categorical(y_test, nb_classes)
model = Sequential()
model.add(Convolution2D(nb_filters[0], image_dimensions, nb_conv[0], nb_conv[0], border_mode='full'))
model.add(Convolution2D(32, 3, 3, border_mode='full',
input_shape=(img_channels, img_rows, img_cols)))
model.add(Activation('relu'))
model.add(Convolution2D(nb_filters[0], nb_filters[0], nb_conv[0], nb_conv[0]))
model.add(Convolution2D(32, 3, 3))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(nb_pool[0], nb_pool[0])))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Convolution2D(nb_filters[1], nb_filters[0], nb_conv[0], nb_conv[0], border_mode='full'))
model.add(Convolution2D(64, 3, 3, border_mode='full'))
model.add(Activation('relu'))
model.add(Convolution2D(nb_filters[1], nb_filters[1], nb_conv[1], nb_conv[1]))
model.add(Convolution2D(64, 3, 3))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(nb_pool[1], nb_pool[1])))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
# the image dimensions are the original dimensions divided by any pooling
# each pixel has a number of filters, determined by the last Convolution2D layer
model.add(Dense(nb_filters[-1] * (shapex / nb_pool[0] / nb_pool[1]) * (shapey / nb_pool[0] / nb_pool[1]), 512))
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(512, nb_classes))
model.add(Dense(nb_classes))
model.add(Activation('softmax'))
# let's train the model using SGD + momentum (how original).
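For reference, the flattened size that the old CIFAR10 code computed by hand, and that the new API now infers from Flatten(), works out as follows (a sketch of the arithmetic only; each 'full' conv followed by a 'valid' conv leaves the 32x32 spatial size unchanged, so only the two 2x2 poolings shrink it):

nb_filters = [32, 64]
nb_pool = [2, 2]
shapex, shapey = 32, 32
# 32x32 -> 16x16 after the first pooling -> 8x8 after the second,
# with 64 filters in the last conv layer
flat_size = nb_filters[-1] * (shapex // nb_pool[0] // nb_pool[1]) * (shapey // nb_pool[0] // nb_pool[1])
print(flat_size)  # 64 * 8 * 8 = 4096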

@ -1,7 +1,7 @@
from __future__ import absolute_import
from __future__ import print_function
import numpy as np
np.random.seed(1337) # for reproducibility
np.random.seed(1337) # for reproducibility
from keras.preprocessing import sequence
from keras.optimizers import RMSprop
@ -25,7 +25,7 @@ max_features = 5000
maxlen = 100
batch_size = 32
embedding_dims = 100
nb_filters = 250
nb_filter = 250
filter_length = 3
hidden_dims = 250
nb_epoch = 3
@ -47,35 +47,29 @@ model = Sequential()
# we start off with an efficient embedding layer which maps
# our vocab indices into embedding_dims dimensions
model.add(Embedding(max_features, embedding_dims))
model.add(Embedding(max_features, embedding_dims, max_length=maxlen))
model.add(Dropout(0.25))
# we add a Convolution1D, which will learn nb_filters
# we add a Convolution1D, which will learn nb_filter
# word group filters of size filter_length:
model.add(Convolution1D(input_dim=embedding_dims,
nb_filter=nb_filters,
model.add(Convolution1D(nb_filter=nb_filter,
filter_length=filter_length,
border_mode="valid",
activation="relu",
subsample_length=1))
# we use standard max pooling (halving the output of the previous layer):
model.add(MaxPooling1D(pool_length=2))
# We flatten the output of the conv layer, so that we can add a vanilla dense layer:
model.add(Flatten())
# Computing the output shape of a conv layer can be tricky;
# for a good tutorial, see: http://cs231n.github.io/convolutional-networks/
output_size = nb_filters * (((maxlen - filter_length) / 1) + 1) / 2
# We add a vanilla hidden layer:
model.add(Dense(output_size, hidden_dims))
model.add(Dense(hidden_dims))
model.add(Dropout(0.25))
model.add(Activation('relu'))
# We project onto a single unit output layer, and squash it with a sigmoid:
model.add(Dense(hidden_dims, 1))
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='rmsprop', class_mode="binary")
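The output_size formula removed above is exactly the bookkeeping the new API now performs internally; with the values used in this example it evaluates as follows (a sketch of the arithmetic only):

maxlen = 100
filter_length = 3
nb_filter = 250
# 'valid' convolution with stride 1: 100 - 3 + 1 = 98 time steps,
# max pooling with pool_length=2 halves that to 49,
# so the Flatten() output has 250 * 49 = 12250 units
output_size = nb_filter * (((maxlen - filter_length) // 1) + 1) // 2
print(output_size)  # 12250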

@ -49,9 +49,9 @@ print('X_test shape:', X_test.shape)
print('Build model...')
model = Sequential()
model.add(Embedding(max_features, 128))
model.add(LSTM(128, 128)) # try using a GRU instead, for fun
model.add(LSTM(128)) # try using a GRU instead, for fun
model.add(Dropout(0.5))
model.add(Dense(128, 1))
model.add(Dense(1))
model.add(Activation('sigmoid'))
# try using different optimizers and different optimizer configs

@ -20,11 +20,11 @@ from sklearn.preprocessing import StandardScaler
Compatible with Python 2.7-3.4. Requires Scikit-Learn and Pandas.
Recommended to run on GPU:
Command: THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python kaggle_otto_nn.py
On EC2 g2.2xlarge instance: 19s/epoch. 6-7 minutes total training time.
Best validation score at epoch 21: 0.4881
Try it at home:
- with/without BatchNormalization (BatchNormalization helps!)
@ -78,7 +78,6 @@ def make_submission(y_prob, ids, encoder, fname):
f.write('\n')
print("Wrote submission to file {}.".format(fname))
print("Loading data...")
X, labels = load_data('train.csv', train=True)
X, scaler = preprocess_data(X)
@ -96,31 +95,29 @@ print(dims, 'dims')
print("Building model...")
model = Sequential()
model.add(Dense(dims, 512, init='glorot_uniform'))
model.add(PReLU((512,)))
model.add(Dense(512, input_shape=(dims,)))
model.add(PReLU())
model.add(BatchNormalization((512,)))
model.add(Dropout(0.5))
model.add(Dense(512, 512, init='glorot_uniform'))
model.add(PReLU((512,)))
model.add(BatchNormalization((512,)))
model.add(Dense(512))
model.add(PReLU())
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(512, 512, init='glorot_uniform'))
model.add(PReLU((512,)))
model.add(BatchNormalization((512,)))
model.add(Dense(512))
model.add(PReLU())
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(512, nb_classes, init='glorot_uniform'))
model.add(Dense(nb_classes))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer="adam")
print("Training model...")
model.fit(X, y, nb_epoch=20, batch_size=128, validation_split=0.15)
print("Generating submission...")
proba = model.predict_proba(X_test)
make_submission(proba, ids, encoder, fname='keras-otto.csv')

@ -4,7 +4,8 @@ from keras.layers.core import Dense, Activation, Dropout
from keras.layers.recurrent import LSTM
from keras.datasets.data_utils import get_file
import numpy as np
import random, sys
import random
import sys
'''
Example script to generate text from Nietzsche's writings.
@ -15,7 +16,7 @@ import random, sys
It is recommended to run this script on GPU, as recurrent
networks are quite computationally intensive.
If you try this script on new data, make sure your corpus
has at least ~100k characters. ~1M is better.
'''
@ -34,7 +35,7 @@ step = 3
sentences = []
next_chars = []
for i in range(0, len(text) - maxlen, step):
sentences.append(text[i : i + maxlen])
sentences.append(text[i: i + maxlen])
next_chars.append(text[i + maxlen])
print('nb sequences:', len(sentences))
@ -50,20 +51,21 @@ for i, sentence in enumerate(sentences):
# build the model: 2 stacked LSTMs
print('Build model...')
model = Sequential()
model.add(LSTM(len(chars), 512, return_sequences=True))
model.add(LSTM(512, return_sequences=True, input_shape=(maxlen, len(chars))))
model.add(Dropout(0.2))
model.add(LSTM(512, 512, return_sequences=False))
model.add(LSTM(512, return_sequences=False))
model.add(Dropout(0.2))
model.add(Dense(512, len(chars)))
model.add(Dense(len(chars)))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
# helper function to sample an index from a probability array
def sample(a, temperature=1.0):
a = np.log(a)/temperature
a = np.exp(a)/np.sum(np.exp(a))
return np.argmax(np.random.multinomial(1,a,1))
a = np.log(a) / temperature
a = np.exp(a) / np.sum(np.exp(a))
return np.argmax(np.random.multinomial(1, a, 1))
# train the model, output generated text after each iteration
for iteration in range(1, 60):
@ -79,7 +81,7 @@ for iteration in range(1, 60):
print('----- diversity:', diversity)
generated = ''
sentence = text[start_index : start_index + maxlen]
sentence = text[start_index: start_index + maxlen]
generated += sentence
print('----- Generating with seed: "' + sentence + '"')
sys.stdout.write(generated)
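The sample() helper above implements temperature sampling: log-probabilities are divided by the temperature before re-normalizing, so low temperatures sharpen the distribution (more conservative text) and high temperatures flatten it (more surprising text). A small self-contained usage sketch, with made-up probabilities for illustration:

import numpy as np

def sample(a, temperature=1.0):
    # same helper as above: rescale log-probabilities by the temperature,
    # re-normalize, then draw one index from the resulting distribution
    a = np.log(a) / temperature
    a = np.exp(a) / np.sum(np.exp(a))
    return np.argmax(np.random.multinomial(1, a, 1))

preds = np.array([0.7, 0.2, 0.1])
print(sample(preds, temperature=0.2))  # almost always 0
print(sample(preds, temperature=1.2))  # noticeably more random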

@ -22,20 +22,20 @@ batch_size = 128
nb_classes = 10
nb_epoch = 12
# shape of the image (SHAPE x SHAPE)
shapex, shapey = 28, 28
# input image dimensions
img_rows, img_cols = 28, 28
# number of convolutional filters to use
nb_filters = 32
# level of pooling to perform (POOL x POOL)
# size of pooling area for max pooling
nb_pool = 2
# level of convolution to perform (CONV x CONV)
# convolution kernel size
nb_conv = 3
# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(X_train.shape[0], 1, shapex, shapey)
X_test = X_test.reshape(X_test.shape[0], 1, shapex, shapey)
X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)
X_train = X_train.astype("float32")
X_test = X_test.astype("float32")
X_train /= 255
@ -50,22 +50,20 @@ Y_test = np_utils.to_categorical(y_test, nb_classes)
model = Sequential()
model.add(Convolution2D(nb_filters, 1, nb_conv, nb_conv, border_mode='full'))
model.add(Convolution2D(nb_filters, nb_conv, nb_conv,
border_mode='full',
input_shape=(1, img_rows, img_cols)))
model.add(Activation('relu'))
model.add(Convolution2D(nb_filters, nb_filters, nb_conv, nb_conv))
model.add(Convolution2D(nb_filters, nb_conv, nb_conv))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(nb_pool, nb_pool)))
model.add(Dropout(0.25))
model.add(Flatten())
# the resulting image after conv and pooling is the original shape
# divided by the pooling with a number of filters for each "pixel"
# (the number of filters is determined by the last Conv2D)
model.add(Dense(nb_filters * (shapex / nb_pool) * (shapey / nb_pool), 128))
model.add(Dense(128))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(128, nb_classes))
model.add(Dense(nb_classes))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adadelta')

@ -55,11 +55,12 @@ Y_test = np_utils.to_categorical(y_test, nb_classes)
print('Evaluate IRNN...')
model = Sequential()
model.add(SimpleRNN(input_dim=1, output_dim=hidden_units,
model.add(SimpleRNN(output_dim=hidden_units,
init=lambda shape: normal(shape, scale=0.001),
inner_init=lambda shape: identity(shape, scale=1.0),
activation='relu', truncate_gradient=BPTT_truncate))
model.add(Dense(hidden_units, nb_classes))
activation='relu', truncate_gradient=BPTT_truncate,
input_shape=(None, 1)))
model.add(Dense(nb_classes))
model.add(Activation('softmax'))
rmsprop = RMSprop(lr=learning_rate)
model.compile(loss='categorical_crossentropy', optimizer=rmsprop)
@ -73,8 +74,8 @@ print('IRNN test accuracy:', scores[1])
print('Compare to LSTM...')
model = Sequential()
model.add(LSTM(1, hidden_units))
model.add(Dense(hidden_units, nb_classes))
model.add(LSTM(hidden_units, input_shape=(None, 1)))
model.add(Dense(nb_classes))
model.add(Activation('softmax'))
rmsprop = RMSprop(lr=learning_rate)
model.compile(loss='categorical_crossentropy', optimizer=rmsprop)

@ -37,13 +37,13 @@ Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)
model = Sequential()
model.add(Dense(784, 128))
model.add(Dense(128, input_shape=(784,)))
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(128, 128))
model.add(Dense(128))
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(128, 10))
model.add(Dense(10))
model.add(Activation('softmax'))
rms = RMSprop()

@ -45,10 +45,10 @@ print('Y_test shape:', Y_test.shape)
print("Building model...")
model = Sequential()
model.add(Dense(max_words, 512))
model.add(Dense(512, input_shape=(max_words,)))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(512, nb_classes))
model.add(Dense(nb_classes))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')

@ -160,7 +160,6 @@ word_index = tokenizer.word_index
reverse_word_index = dict([(v, k) for k, v in list(word_index.items())])
def embed_word(w):
i = word_index.get(w)
if (not i) or (i < skip_top) or (i >= max_features):

@ -25,15 +25,19 @@ class PReLU(MaskedLayer):
Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
http://arxiv.org/pdf/1502.01852v1.pdf
'''
def __init__(self, input_shape, init='zero', weights=None, **kwargs):
super(PReLU, self).__init__(**kwargs)
def __init__(self, init='zero', weights=None, **kwargs):
self.init = initializations.get(init)
self.initial_weights = weights
super(PReLU, self).__init__(**kwargs)
def build(self):
input_shape = self.input_shape[1:]
self.alphas = self.init(input_shape)
self.params = [self.alphas]
self.input_shape = input_shape
if weights is not None:
self.set_weights(weights)
if self.initial_weights is not None:
self.set_weights(self.initial_weights)
del self.initial_weights
def get_output(self, train):
X = self.get_input(train)
@ -43,7 +47,6 @@ class PReLU(MaskedLayer):
def get_config(self):
return {"name": self.__class__.__name__,
"input_shape": self.input_shape,
"init": self.init.__name__}
@ -55,19 +58,23 @@ class ParametricSoftplus(MaskedLayer):
Inferring Nonlinear Neuronal Computation Based on Physiologically Plausible Inputs
http://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1003143
'''
def __init__(self, input_shape, alpha_init=0.2,
beta_init=5.0, weights=None, **kwargs):
super(ParametricSoftplus, self).__init__(**kwargs)
def __init__(self, alpha_init=0.2, beta_init=5.0,
weights=None, **kwargs):
self.alpha_init = alpha_init
self.beta_init = beta_init
self.alphas = sharedX(alpha_init * np.ones(input_shape))
self.betas = sharedX(beta_init * np.ones(input_shape))
self.initial_weights = weights
super(ParametricSoftplus, self).__init__(**kwargs)
def build(self):
input_shape = self.input_shape[1:]
self.alphas = sharedX(self.alpha_init * np.ones(input_shape))
self.betas = sharedX(self.beta_init * np.ones(input_shape))
self.params = [self.alphas, self.betas]
self.input_shape = input_shape
if weights is not None:
self.set_weights(weights)
if self.initial_weights is not None:
self.set_weights(self.initial_weights)
del self.initial_weights
def get_output(self, train):
X = self.get_input(train)
@ -75,7 +82,6 @@ class ParametricSoftplus(MaskedLayer):
def get_config(self):
return {"name": self.__class__.__name__,
"input_shape": self.input_shape,
"alpha_init": self.alpha_init,
"beta_init": self.beta_init}

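Both activation layers above are refactored along the same lines: __init__ now only records the configuration (and any initial weights), while the parameters are created in build(), where self.input_shape is finally known because the layer has been connected to a model. A minimal sketch of a custom layer written against this pattern; the layer itself (a trainable per-feature bias) is hypothetical and only illustrates the structure:

from keras.layers.core import MaskedLayer
from keras.utils.theano_utils import shared_zeros

class LearnedBias(MaskedLayer):
    '''Hypothetical example layer: adds a trainable per-feature bias.'''
    def __init__(self, weights=None, **kwargs):
        # record configuration only; no shapes are known yet
        self.initial_weights = weights
        super(LearnedBias, self).__init__(**kwargs)

    def build(self):
        # called once the layer is connected, so self.input_shape is available
        input_shape = self.input_shape[1:]
        self.b = shared_zeros(input_shape)
        self.params = [self.b]
        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

    def get_output(self, train=False):
        return self.get_input(train) + self.b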
@ -51,7 +51,7 @@ def pool_output_length(input_length, pool_size, ignore_border, stride):
class Convolution1D(Layer):
input_ndim = 3
def __init__(self, input_dim, nb_filter, filter_length,
def __init__(self, nb_filter, filter_length,
init='uniform', activation='linear', weights=None,
border_mode='valid', subsample_length=1,
W_regularizer=None, b_regularizer=None, activity_regularizer=None,
@ -82,9 +82,9 @@ class Convolution1D(Layer):
def build(self):
input_dim = self.input_shape[2]
self.input = T.tensor3()
self.W_shape = (nb_filter, input_dim, filter_length, 1)
self.W_shape = (self.nb_filter, input_dim, self.filter_length, 1)
self.W = self.init(self.W_shape)
self.b = shared_zeros((nb_filter,))
self.b = shared_zeros((self.nb_filter,))
self.params = [self.W, self.b]
self.regularizers = []
@ -190,9 +190,9 @@ class Convolution2D(Layer):
def build(self):
stack_size = self.input_shape[1]
self.input = T.tensor4()
self.W_shape = (nb_filter, stack_size, nb_row, nb_col)
self.W_shape = (self.nb_filter, stack_size, self.nb_row, self.nb_col)
self.W = self.init(self.W_shape)
self.b = shared_zeros((nb_filter,))
self.b = shared_zeros((self.nb_filter,))
self.params = [self.W, self.b]
self.regularizers = []

@ -21,7 +21,8 @@ class Layer(object):
def __init__(self, **kwargs):
if 'input_shape' in kwargs:
self.set_input_shape(kwargs['input_shape'])
self.params = []
if not hasattr(self, 'params'):
self.params = []
def init_updates(self):
self.updates = []
@ -59,7 +60,7 @@ class Layer(object):
elif hasattr(self, '_input_shape'):
return self._input_shape
else:
raise Exception('Layer is not connected.')
raise Exception('Layer is not connected. Did you forget to set "input_shape"?')
def set_input_shape(self, input_shape):
if type(input_shape) not in [tuple, list]:
@ -283,7 +284,7 @@ class Merge(Layer):
elif self.mode == 'concat':
output_shape = list(input_shapes[0])
for shape in input_shapes[1:]:
output_shape[self.concat_axis] += shape[concat_axis]
output_shape[self.concat_axis] += shape[self.concat_axis]
return tuple(output_shape)
def get_params(self):
@ -528,7 +529,7 @@ class Dense(Layer):
self.input = T.matrix()
self.W = self.init((input_dim, self.output_dim))
self.b = shared_zeros((self.output_dim))
self.b = shared_zeros((self.output_dim,))
self.params = [self.W, self.b]
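Two small notes on the core.py changes above. The hasattr guard in Layer.__init__ is there because passing input_shape can trigger build() (via set_input_shape) before the rest of __init__ runs, and build() may already have populated self.params; the guard avoids clobbering them. The Dense bias change is about tuple syntax: (self.output_dim) is just a parenthesized integer, while (self.output_dim,) is a one-element tuple, so the trailing comma makes the intended shape unambiguous (np.zeros happens to accept both). A quick illustration of that difference, as a sketch:

import numpy as np

output_dim = 128
print(type((output_dim)))             # int: parentheses alone do not make a tuple
print(type((output_dim,)))            # tuple: the trailing comma does
print(np.zeros((output_dim,)).shape)  # (128,)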

@ -19,38 +19,40 @@ class Embedding(Layer):
'''
input_ndim = 2
def __init__(self, input_dim, output_dim, init='uniform',
def __init__(self, input_dim, output_dim, init='uniform', max_length=None,
W_regularizer=None, activity_regularizer=None, W_constraint=None,
mask_zero=False, weights=None, **kwargs):
super(Embedding, self).__init__(**kwargs)
self.init = initializations.get(init)
self.input_dim = input_dim
self.output_dim = output_dim
self.input = T.imatrix()
self.W = self.init((self.input_dim, self.output_dim))
self.init = initializations.get(init)
self.max_length = max_length
self.mask_zero = mask_zero
self.params = [self.W]
self.W_constraint = constraints.get(W_constraint)
self.constraints = [self.W_constraint]
self.regularizers = []
self.W_regularizer = regularizers.get(W_regularizer)
self.activity_regularizer = regularizers.get(activity_regularizer)
self.initial_weights = weights
kwargs['input_shape'] = (self.input_dim,)
super(Embedding, self).__init__(**kwargs)
def build(self):
self.input = T.imatrix()
self.W = self.init((self.input_dim, self.output_dim))
self.params = [self.W]
self.regularizers = []
if self.W_regularizer:
self.W_regularizer.set_param(self.W)
self.regularizers.append(self.W_regularizer)
self.activity_regularizer = regularizers.get(activity_regularizer)
if self.activity_regularizer:
self.activity_regularizer.set_layer(self)
self.regularizers.append(self.activity_regularizer)
if weights is not None:
self.set_weights(weights)
if self.initial_weights is not None:
self.set_weights(self.initial_weights)
def get_output_mask(self, train=None):
X = self.get_input(train)
@ -61,7 +63,7 @@ class Embedding(Layer):
@property
def output_shape(self):
return (self.input_shape[0], None, self.output_dim)
return (self.input_shape[0], self.max_length, self.output_dim)
def get_output(self, train=False):
X = self.get_input(train)
@ -73,6 +75,8 @@ class Embedding(Layer):
"input_dim": self.input_dim,
"output_dim": self.output_dim,
"init": self.init.__name__,
"max_length": self.max_length,
"mask_zero": self.mask_zero,
"activity_regularizer": self.activity_regularizer.get_config() if self.activity_regularizer else None,
"W_regularizer": self.W_regularizer.get_config() if self.W_regularizer else None,
"W_constraint": self.W_constraint.get_config() if self.W_constraint else None}

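The new max_length argument on Embedding (together with the reworked output_shape property) is what lets the imdb_cnn example above drop its hand-computed output_size: when the sequence length is known, the time dimension of the Embedding output is fixed and downstream Flatten/Dense layers can infer their sizes; when it is left as None, the time dimension stays unknown. A minimal sketch using the constructor shown above, with the imdb_cnn numbers (5000-word vocabulary, 100-dimensional embeddings, sequences of length 100):

from keras.models import Sequential
from keras.layers.core import Flatten, Dense
from keras.layers.embeddings import Embedding

model = Sequential()
model.add(Embedding(5000, 100, max_length=100))  # sequence length now known downstream
model.add(Flatten())                             # flattened size (100 steps * 100 dims) is inferred
model.add(Dense(1))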
@ -21,6 +21,7 @@ class BatchNormalization(Layer):
self.epsilon = epsilon
self.mode = mode
self.momentum = momentum
self.initial_weights = weights
super(BatchNormalization, self).__init__(**kwargs)
def build(self):
@ -34,8 +35,9 @@ class BatchNormalization(Layer):
self.params = [self.gamma, self.beta]
self.running_mean = shared_zeros(input_shape)
self.running_std = shared_ones((input_shape))
if weights is not None:
self.set_weights(weights)
if self.initial_weights is not None:
self.set_weights(self.initial_weights)
del self.initial_weights
def get_weights(self):
return super(BatchNormalization, self).get_weights() + [self.running_mean.get_value(), self.running_std.get_value()]